diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 470e2e5a5..f17941405 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -49,7 +49,7 @@ jobs: - name: build docs run: | cd docs - RTD_BUILD=1 make html SPHINXOPTS="-W --keep-going" + DOCS_BUILD=1 make html SPHINXOPTS="-W --keep-going" # set environment variable `DOCS_VERSION_DIR` to either the pr-branch name, "dev", or the release version tag - name: set output pr diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index a4f3e9a67..6d369782d 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -160,6 +160,7 @@ def __init__( self._alpha_mode = AlphaMode(alpha_mode) self._visible = Visible(visible) self._block_events = False + self._block_handlers = list() self._axes: Axes = None @@ -242,6 +243,11 @@ def block_events(self) -> bool: def block_events(self, value: bool): self._block_events = value + @property + def block_handlers(self) -> list: + """Used to block event handlers for a graphic and prevent recursion.""" + return self._block_handlers + @property def world_object(self) -> pygfx.WorldObject: """Associated pygfx WorldObject. Always returns a proxy, real object cannot be accessed directly.""" @@ -370,6 +376,9 @@ def _handle_event(self, callback, event: pygfx.Event): if self.block_events: return + if callback in self._block_handlers: + return + if event.type in self._features: # for feature events event._target = self.world_object diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 779310476..96e2dd102 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -320,7 +320,7 @@ def __repr__(self): def block_reentrance(set_value): # decorator to block re-entrant set_value methods # useful when creating complex, circular, bidirectional event graphs - def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): + def set_value_wrapper(self: GraphicFeature, graphic_or_key, value, **kwargs): """ wraps GraphicFeature.set_value @@ -336,7 +336,7 @@ def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): try: # block re-execution of set_value until it has *fully* finished executing self._reentrant_block = True - set_value(self, graphic_or_key, value) + set_value(self, graphic_or_key, value, **kwargs) except Exception as exc: # raise original exception raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! diff --git a/fastplotlib/graphics/features/_selection_features.py b/fastplotlib/graphics/features/_selection_features.py index 9b30dd70c..1f049f0cb 100644 --- a/fastplotlib/graphics/features/_selection_features.py +++ b/fastplotlib/graphics/features/_selection_features.py @@ -118,7 +118,7 @@ def axis(self) -> str: return self._axis @block_reentrance - def set_value(self, selector, value: Sequence[float]): + def set_value(self, selector, value: Sequence[float], *, change: str = "full"): """ Set start, stop range of selector @@ -182,7 +182,9 @@ def set_value(self, selector, value: Sequence[float]): if len(self._event_handlers) < 1: return - event = GraphicFeatureEvent(self._property_name, {"value": self.value}) + event = GraphicFeatureEvent( + self._property_name, {"value": self.value, "change": change} + ) event.get_selected_indices = selector.get_selected_indices event.get_selected_data = selector.get_selected_data diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 1eaf54bb6..9a62af2bc 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -158,19 +158,26 @@ def __init__( self._interpolation = ImageInterpolation(interpolation) # set map to None for RGB images - if self._data.value.ndim > 2: + if self._data.value.ndim == 3: self._cmap = None + self._cmap_interpolation = None _map = None - else: + + elif self._data.value.ndim == 2: # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - _map = pygfx.TextureMap( self._cmap.texture, filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + else: + raise ValueError( + f"ImageGraphic `data` must have 2 dimensions for grayscale images, or 3 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) # one common material is used for every Texture chunk self._material = pygfx.ImageBasicMaterial( @@ -223,8 +230,6 @@ def cmap(self) -> str | None: if self._cmap is not None: return self._cmap.value - return None - @cmap.setter def cmap(self, name: str): if self.data.value.ndim > 2: @@ -259,9 +264,10 @@ def interpolation(self, value: str): self._interpolation.set_value(self, value) @property - def cmap_interpolation(self) -> str: - """cmap interpolation method""" - return self._cmap_interpolation.value + def cmap_interpolation(self) -> str | None: + """cmap interpolation method, 'linear' or 'nearest'. `None` if image is RGB(A)""" + if self._cmap_interpolation is not None: + return self._cmap_interpolation.value @cmap_interpolation.setter def cmap_interpolation(self, value: str): diff --git a/fastplotlib/graphics/image_volume.py b/fastplotlib/graphics/image_volume.py index db616b30d..b8bed454e 100644 --- a/fastplotlib/graphics/image_volume.py +++ b/fastplotlib/graphics/image_volume.py @@ -211,16 +211,28 @@ def __init__( self._interpolation = ImageInterpolation(interpolation) - # TODO: I'm assuming RGB volume images aren't supported??? - # use TextureMap for grayscale images - self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - - self._texture_map = pygfx.TextureMap( - self._cmap.texture, - filter=self._cmap_interpolation.value, - wrap="clamp-to-edge", - ) + if self._data.value.ndim == 4: + # set map to None for RGB image volumes + self._cmap = None + self._texture_map = None + self._cmap_interpolation = None + + elif self._data.value.ndim == 3: + # use TextureMap for grayscale images + self._cmap = ImageCmap(cmap) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) + self._texture_map = pygfx.TextureMap( + self._cmap.texture, + filter=self._cmap_interpolation.value, + wrap="clamp-to-edge", + ) + else: + raise ValueError( + f"ImageVolumeGraphic `data` must have 3 dimensions for grayscale images, " + f"or 4 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) self._plane = VolumeSlicePlane(plane) self._threshold = VolumeIsoThreshold(threshold) @@ -282,9 +294,10 @@ def mode(self, mode: str): self._mode.set_value(self, mode) @property - def cmap(self) -> str: + def cmap(self) -> str | None: """Get or set colormap name""" - return self._cmap.value + if self._cmap is not None: + return self._cmap.value @cmap.setter def cmap(self, name: str): @@ -318,9 +331,10 @@ def interpolation(self, value: str): self._interpolation.set_value(self, value) @property - def cmap_interpolation(self) -> str: + def cmap_interpolation(self) -> str | None: """Get or set the cmap interpolation method""" - return self._cmap_interpolation.value + if self._cmap_interpolation is not None: + return self._cmap_interpolation.value @cmap_interpolation.setter def cmap_interpolation(self, value: str): diff --git a/fastplotlib/graphics/selectors/_linear_region.py b/fastplotlib/graphics/selectors/_linear_region.py index 70a8dffa8..8a8583ae9 100644 --- a/fastplotlib/graphics/selectors/_linear_region.py +++ b/fastplotlib/graphics/selectors/_linear_region.py @@ -472,9 +472,9 @@ def _move_graphic(self, move_info: MoveInfo): if move_info.source == self._edges[0]: # change only left or bottom bound new_min = min(cur_min + delta, cur_max) - self._selection.set_value(self, (new_min, cur_max)) + self._selection.set_value(self, (new_min, cur_max), change="min") elif move_info.source == self._edges[1]: # change only right or top bound new_max = max(cur_max + delta, cur_min) - self._selection.set_value(self, (cur_min, new_max)) + self._selection.set_value(self, (cur_min, new_max), change="max") diff --git a/fastplotlib/graphics/utils.py b/fastplotlib/graphics/utils.py index 6be5aefc4..f32d80809 100644 --- a/fastplotlib/graphics/utils.py +++ b/fastplotlib/graphics/utils.py @@ -1,13 +1,16 @@ from contextlib import contextmanager +from typing import Callable, Iterable from ._base import Graphic @contextmanager -def pause_events(*graphics: Graphic): +def pause_events(*graphics: Graphic, event_handlers: Iterable[Callable] = None): """ Context manager for pausing Graphic events. + Optionally pass in only specific event handlers which are blocked. Other events for the graphic will not be blocked. + Examples -------- @@ -30,8 +33,14 @@ def pause_events(*graphics: Graphic): original_vals = [g.block_events for g in graphics] for g in graphics: - g.block_events = True + if event_handlers is not None: + g.block_handlers.extend([e for e in event_handlers]) + else: + g.block_events = True yield for g, value in zip(graphics, original_vals): - g.block_events = value + if event_handlers is not None: + g.block_handlers.clear() + else: + g.block_events = value diff --git a/fastplotlib/layouts/_figure.py b/fastplotlib/layouts/_figure.py index 8fd5dc666..59f93b15e 100644 --- a/fastplotlib/layouts/_figure.py +++ b/fastplotlib/layouts/_figure.py @@ -554,7 +554,7 @@ def show_tooltips(self, val: bool): if val: # register all graphics - for subplot in self: + for subplot in self._subplots.ravel(): for graphic in subplot.graphics: self._tooltip_manager.register(graphic) @@ -572,7 +572,7 @@ def _render(self, draw=True): # call the animation functions before render self._call_animate_functions(self._animate_funcs_pre) - for subplot in self: + for subplot in self._subplots.ravel(): subplot._render() # overlay render pass @@ -639,14 +639,14 @@ def show( sidecar_kwargs = dict() # flip y-axis if ImageGraphics are present - for subplot in self: + for subplot in self._subplots.ravel(): for g in subplot.graphics: if isinstance(g, ImageGraphic): subplot.camera.local.scale_y *= -1 break if autoscale: - for subplot in self: + for subplot in self._subplots.ravel(): if maintain_aspect is None: _maintain_aspect = subplot.camera.maintain_aspect else: @@ -655,7 +655,7 @@ def show( # set axes visibility if False if not axes_visible: - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.visible = False # parse based on canvas type @@ -679,15 +679,15 @@ def show( elif self.canvas.__class__.__name__ == "OffscreenRenderCanvas": # for test and docs gallery screenshots self._fpl_reset_layout() - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.update_using_camera() # render call is blocking only on github actions for some reason, # but not for rtd build, this is a workaround # for CI tests, the render call works if it's in test_examples # but it is necessary for the gallery images too so that's why this check is here - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": self._render() else: # assume GLFW @@ -803,7 +803,7 @@ def clear_animations(self, removal: str = None): def clear(self): """Clear all Subplots""" - for subplot in self: + for subplot in self._subplots.ravel(): subplot.clear() def export_numpy(self, rgb: bool = False) -> np.ndarray: @@ -962,18 +962,20 @@ def __getitem__(self, index: str | int | tuple[int, int]) -> Subplot: return subplot raise IndexError(f"no subplot with given name: {index}") + if isinstance(index, (int, np.integer)): + return self._subplots.ravel()[index] + if isinstance(self.layout, GridLayout): return self._subplots[index[0], index[1]] - return self._subplots[index] + raise TypeError( + f"Can index figure using subplot name, numerical subplot index, or a " + f"tuple[int, int] if the layout is a grid" + ) def __iter__(self): - self._current_iter = iter(range(len(self))) - return self - - def __next__(self) -> Subplot: - pos = self._current_iter.__next__() - return self._subplots.ravel()[pos] + for subplot in self._subplots.ravel(): + yield subplot def __len__(self): """number of subplots""" @@ -988,6 +990,6 @@ def __repr__(self): return ( f"fastplotlib.{self.__class__.__name__}" f" Subplots:\n" - f"\t{newline.join(subplot.__str__() for subplot in self)}" + f"\t{newline.join(subplot.__str__() for subplot in self._subplots.ravel())}" f"\n" ) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index 01721780c..640f9a85a 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -226,7 +226,10 @@ def controller(self, new_controller: str | pygfx.Controller): # pygfx plans on refactoring viewports anyways if self.parent is not None: if self.parent.__class__.__name__.endswith("Figure"): - for subplot in self.parent: + # always use figure._subplots.ravel() in internal fastplotlib code + # otherwise if we use `for subplot in figure`, this could conflict + # with a user's iterator where they are doing `for subplot in figure` !!! + for subplot in self.parent._subplots.ravel(): if subplot.camera in cameras_list: new_controller.register_events(subplot.viewport) subplot._controller = new_controller diff --git a/fastplotlib/tools/_histogram_lut.py b/fastplotlib/tools/_histogram_lut.py index 7507a7ff2..36f840970 100644 --- a/fastplotlib/tools/_histogram_lut.py +++ b/fastplotlib/tools/_histogram_lut.py @@ -6,422 +6,410 @@ import pygfx -from ..utils import subsample_array +from ..utils import subsample_array, RenderQueue from ..graphics import LineGraphic, ImageGraphic, ImageVolumeGraphic, TextGraphic from ..graphics.utils import pause_events from ..graphics._base import Graphic +from ..graphics.features import GraphicFeatureEvent from ..graphics.selectors import LinearRegionSelector -def _get_image_graphic_events(image_graphic: ImageGraphic) -> list[str]: - """Small helper function to return the relevant events for an ImageGraphic""" - events = ["vmin", "vmax"] +def _format_value(value: float): + abs_val = abs(value) + if abs_val < 0.01 or abs_val > 9_999: + return f"{value:.2e}" + else: + return f"{value:.2f}" - if not image_graphic.data.value.ndim > 2: - events.append("cmap") - # if RGB(A), do not add cmap - - return events - - -# TODO: This is a widget, we can think about a BaseWidget class later if necessary class HistogramLUTTool(Graphic): def __init__( self, - data: np.ndarray, - images: ( - ImageGraphic - | ImageVolumeGraphic - | Sequence[ImageGraphic | ImageVolumeGraphic] - ), - nbins: int = 100, - flank_divisor: float = 5.0, + histogram: tuple[np.ndarray, np.ndarray], + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None = None, **kwargs, ): """ - HistogramLUT tool that can be used to control the vmin, vmax of ImageGraphics or ImageVolumeGraphics. - If used to control multiple images or image volumes it is assumed that they share a representation of - the same data, and that their histogram, vmin, and vmax are identical. For example, displaying a - ImageVolumeGraphic and several images that represent slices of the same volume data. + A histogram tool that allows adjusting the vmin, vmax of images. + Also allows changing the cmap LUT for grayscale images and displays a colorbar. Parameters ---------- - data: np.ndarray - - images: ImageGraphic | ImageVolumeGraphic | tuple[ImageGraphic | ImageVolumeGraphic] - - nbins: int, defaut 100. - Total number of bins used in the histogram + histogram: tuple[np.ndarray, np.ndarray] + [frequency, bin_edges], must be 100 bins - flank_divisor: float, default 5.0. - Fraction of empty histogram bins on the tails of the distribution set `np.inf` for no flanks + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] + the images that are managed by the histogram tool - kwargs: passed to ``Graphic`` + kwargs: + passed to ``Graphic`` """ - super().__init__(**kwargs) - - self._nbins = nbins - self._flank_divisor = flank_divisor - - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - self._images = images + super().__init__(**kwargs) - self._data = weakref.proxy(data) + if len(histogram) != 2: + raise TypeError - self._scale_factor: float = 1.0 + self._block_reentrance = False + self._images = list() - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) + self._bin_centers_flanked = np.zeros(120, dtype=np.float64) + self._freq_flanked = np.zeros(120, dtype=np.float32) - line_data = np.column_stack([hist_scaled, edges_flanked]) + # 100 points for the histogram, 10 points on each side for the flank + line_data = np.column_stack( + [np.zeros(120, dtype=np.float32), np.arange(0, 120)] + ) - self._histogram_line = LineGraphic( - line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(0, 0, -1) + # line that displays the histogram + self._line = LineGraphic( + line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(1, 0, 0) + ) + self._line.world_object.local.scale_x = -1 + + # vmin, vmax selector + self._selector = LinearRegionSelector( + selection=(10, 110), + limits=(0, 119), + size=1.5, + center=0.5, # frequency data are normalized between 0-1 + axis="y", + parent=self._line, ) - bounds = (edges[0] * self._scale_factor, edges[-1] * self._scale_factor) - limits = (edges_flanked[0], edges_flanked[-1]) - size = 120 # since it's scaled to 100 - origin = (hist_scaled.max() / 2, 0) + self._selector.add_event_handler(self._selector_event_handler, "selection") - self._linear_region_selector = LinearRegionSelector( - selection=bounds, - limits=limits, - size=size, - center=origin[0], - axis="y", - parent=self._histogram_line, + self._colorbar = ImageGraphic( + data=np.zeros([120, 2]), interpolation="linear", offset=(1.5, 0, 0) ) - self._vmin = self.images[0].vmin - self._vmax = self.images[0].vmax + # make the colorbar thin + self._colorbar.world_object.local.scale_x = 0.15 + self._colorbar.add_event_handler(self._open_cmap_picker, "click") - # there will be a small difference with the histogram edges so this makes them both line up exactly - self._linear_region_selector.selection = ( - self._vmin * self._scale_factor, - self._vmax * self._scale_factor, + # colorbar ruler + self._ruler = pygfx.Ruler( + end_pos=(0, 119, 0), + alpha_mode="solid", + render_queue=RenderQueue.axes, + tick_side="right", + tick_marker="tick_right", + tick_format=self._ruler_tick_map, + min_tick_distance=10, ) + self._ruler.local.x = 1.75 - vmin_str, vmax_str = self._get_vmin_vmax_str() + # TODO: need to auto-scale using the text so it appears nicely, will do later + self._ruler.visible = False self._text_vmin = TextGraphic( - text=vmin_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="top-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - + # this is to make sure clicking text doesn't conflict with the selector tool + # since the text appears near the selector tool self._text_vmin.world_object.material.pick_write = False self._text_vmax = TextGraphic( - text=vmax_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="bottom-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - self._text_vmax.world_object.material.pick_write = False - widget_wo = pygfx.Group() - widget_wo.add( - self._histogram_line.world_object, - self._linear_region_selector.world_object, + # add all the world objects to a pygfx.Group + wo = pygfx.Group() + wo.add( + self._line.world_object, + self._selector.world_object, + self._colorbar.world_object, + self._ruler, self._text_vmin.world_object, self._text_vmax.world_object, ) + self._set_world_object(wo) - self._set_world_object(widget_wo) + # for convenience, a list that stores all the graphics managed by the histogram LUT tool + self._children = [ + self._line, + self._selector, + self._colorbar, + self._text_vmin, + self._text_vmax, + ] - self.world_object.local.scale_x *= -1 + # set histogram + self.histogram = histogram - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) + # set the images + self.images = images - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) + def _fpl_add_plot_area_hook(self, plot_area): + self._plot_area = plot_area - self._linear_region_selector.add_event_handler( - self._linear_region_handler, "selection" - ) + for child in self._children: + # need all of them to call the add_plot_area_hook so that events are connected correctly + # example, the linear region selector needs all the canvas events to be connected + child._fpl_add_plot_area_hook(plot_area) - ig_events = _get_image_graphic_events(self.images[0]) + if hasattr(self._plot_area, "size"): + # if it's in a dock area + self._plot_area.size = 80 - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + # disable the controller in this plot area + self._plot_area.controller.enabled = False + self._plot_area.auto_scale(maintain_aspect=False) - # colorbar for grayscale images - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + # tick text for colorbar ruler doesn't show without this call + self._ruler.update(plot_area.camera, plot_area.canvas.get_logical_size()) - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + def _ruler_tick_map(self, bin_index, *args): + return f"{self._bin_centers_flanked[int(bin_index)]:.2f}" - def _make_colorbar(self, edges_flanked) -> ImageGraphic: - # use the histogram edge values as data for an - # image with 2 columns, this will be our colorbar! - colorbar_data = np.column_stack( - [ - np.linspace( - edges_flanked[0], edges_flanked[-1], ceil(np.ptp(edges_flanked)) - ) - ] - * 2 - ).astype(np.float32) - - colorbar_data /= self._scale_factor - - cbar = ImageGraphic( - data=colorbar_data, - vmin=self.vmin, - vmax=self.vmax, - cmap=self.images[0].cmap, - interpolation="linear", - offset=(-55, edges_flanked[0], -1), - ) + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray]: + """histogram [frequency, bin_centers]. Frequency is flanked by 10 zeros on both sides""" + return self._freq_flanked, self._bin_centers_flanked - cbar.world_object.world.scale_x = 20 - self._cmap = self.images[0].cmap + @histogram.setter + def histogram( + self, histogram: tuple[np.ndarray, np.ndarray], limits: tuple[int, int] = None + ): + """set histogram with pre-compuated [frequency, edges], must have exactly 100 bins""" - return cbar + freq, edges = histogram - def _get_vmin_vmax_str(self) -> tuple[str, str]: - if self.vmin < 0.001 or self.vmin > 99_999: - vmin_str = f"{self.vmin:.2e}" - else: - vmin_str = f"{self.vmin:.2f}" + if freq.max() > 0: + # if the histogram is made from an empty array, then the max freq will be 0 + # we don't want to divide by 0 because then we just get nans + freq = freq / freq.max() - if self.vmax < 0.001 or self.vmax > 99_999: - vmax_str = f"{self.vmax:.2e}" - else: - vmax_str = f"{self.vmax:.2f}" + bin_centers = 0.5 * (edges[1:] + edges[:-1]) - return vmin_str, vmax_str + step = bin_centers[1] - bin_centers[0] - def _fpl_add_plot_area_hook(self, plot_area): - self._plot_area = plot_area - self._linear_region_selector._fpl_add_plot_area_hook(plot_area) - self._histogram_line._fpl_add_plot_area_hook(plot_area) + under_flank = np.linspace(bin_centers[0] - step * 10, bin_centers[0] - step, 10) + over_flank = np.linspace( + bin_centers[-1] + step, bin_centers[-1] + step * 10, 10 + ) + self._bin_centers_flanked[:] = np.concatenate( + [under_flank, bin_centers, over_flank] + ) + + self._freq_flanked[10:110] = freq - self._plot_area.auto_scale() - self._plot_area.controller.enabled = True + self._line.data[:, 0] = self._freq_flanked + self._colorbar.data = np.column_stack( + [self._bin_centers_flanked, self._bin_centers_flanked] + ) - def _calculate_histogram(self, data): + # self.vmin, self.vmax = bin_centers[0], bin_centers[-1] - # get a subsampled view of this array - data_ss = subsample_array(data, max_size=int(1e6)) # 1e6 is default - hist, edges = np.histogram(data_ss, bins=self._nbins) + if hasattr(self, "plot_area"): + self._ruler.update( + self._plot_area.camera, self._plot_area.canvas.get_logical_size() + ) - # used if data ptp <= 10 because event things get weird - # with tiny world objects due to floating point error - # so if ptp <= 10, scale up by a factor - data_interval = edges[-1] - edges[0] - self._scale_factor: int = max(1, 100 * int(10 / data_interval)) + @property + def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic, ...] | None: + """get or set the managed images""" + return tuple(self._images) - edges = edges * self._scale_factor + @images.setter + def images(self, new_images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None): + self._disconnect_images() + self._images.clear() - bin_width = edges[1] - edges[0] + if new_images is None: + return - flank_nbins = int(self._nbins / self._flank_divisor) - flank_size = flank_nbins * bin_width + if isinstance(new_images, (ImageGraphic, ImageVolumeGraphic)): + new_images = [new_images] - flank_left = np.arange(edges[0] - flank_size, edges[0], bin_width) - flank_right = np.arange( - edges[-1] + bin_width, edges[-1] + flank_size, bin_width - ) + if not all( + [ + isinstance(image, (ImageGraphic, ImageVolumeGraphic)) + for image in new_images + ] + ): + raise TypeError - edges_flanked = np.concatenate((flank_left, edges, flank_right)) + for image in new_images: + if image.cmap is not None: + self._colorbar.visible = True + break + else: + self._colorbar.visible = False - hist_flanked = np.concatenate( - (np.zeros(flank_nbins), hist, np.zeros(flank_nbins)) - ) + self._images = list(new_images) - # scale 0-100 to make it easier to see - # float32 data can produce unnecessarily high values - hist_scale_value = hist_flanked.max() - if np.allclose(hist_scale_value, 0): - hist_scale_value = 1 - hist_scaled = hist_flanked / (hist_scale_value / 100) + # reset vmin, vmax using first image + self.vmin = self._images[0].vmin + self.vmax = self._images[0].vmax - if edges_flanked.size > hist_scaled.size: - # we don't care about accuracy here so if it's off by 1-2 bins that's fine - edges_flanked = edges_flanked[: hist_scaled.size] + if self._images[0].cmap is not None: + self._colorbar.cmap = self._images[0].cmap - return hist, edges, hist_scaled, edges_flanked + # connect event handlers + for image in self._images: + image.add_event_handler(self._image_event_handler, "vmin", "vmax") + image.add_event_handler(self._disconnect_images, "deleted") + if image.cmap is not None: + image.add_event_handler( + self._image_event_handler, "vmin", "vmax", "cmap" + ) - def _linear_region_handler(self, ev): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - selected_ixs = self._linear_region_selector.selection - vmin, vmax = selected_ixs[0], selected_ixs[1] - vmin, vmax = vmin / self._scale_factor, vmax / self._scale_factor - self.vmin, self.vmax = vmin, vmax + def _disconnect_images(self, *args): + """disconnect event handlers of the managed images""" + for image in self._images: + for ev, handlers in image.event_handlers: + if self._image_event_handler in handlers: + image.remove_event_handler(self._image_event_handler, ev) - def _image_cmap_handler(self, ev): - setattr(self, ev.type, ev.info["value"]) + def _image_event_handler(self, ev): + """when the image vmin, vmax, or cmap changes it will update the HistogramLUTTool""" + new_value = ev.info["value"] + setattr(self, ev.type, new_value) @property def cmap(self) -> str: - return self._cmap + """get or set the colormap, only for grayscale images""" + return self._colorbar.cmap @cmap.setter def cmap(self, name: str): - if self._colorbar is None: + if self._block_reentrance: return - with pause_events(*self.images): - for ig in self.images: - ig.cmap = name + if name is None: + return - self._cmap = name + self._block_reentrance = True + try: self._colorbar.cmap = name + with pause_events( + *self._images, event_handlers=[self._image_event_handler] + ): + for image in self._images: + if image.cmap is None: + # rgb(a) images have no cmap + continue + + image.cmap = name + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False + @property def vmin(self) -> float: - return self._vmin + """get or set the vmin, the lower contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[0]) + return self._bin_centers_flanked[index] @vmin.setter def vmin(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - value * self._scale_factor, - self._linear_region_selector.selection[1], - ) - for ig in self.images: - ig.vmin = value + if self._block_reentrance: + return + self._block_reentrance = True + try: + index_min = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (index_min, self._selector.selection[1]) - self._vmin = value - if self._colorbar is not None: - self._colorbar.vmin = value + self._colorbar.vmin = value - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) - self._text_vmin.text = vmin_str + self._text_vmin.text = _format_value(value) + self._text_vmin.offset = (-0.45, self._selector.selection[0], 0) + + for image in self._images: + image.vmin = value + + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False @property def vmax(self) -> float: - return self._vmax + """get or set the vmax, the upper contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[1]) + return self._bin_centers_flanked[index] @vmax.setter def vmax(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - self._linear_region_selector.selection[0], - value * self._scale_factor, - ) - - for ig in self.images: - ig.vmax = value - - self._vmax = value - if self._colorbar is not None: - self._colorbar.vmax = value - - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) - self._text_vmax.text = vmax_str - - def set_data(self, data, reset_vmin_vmax: bool = True): - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) - - line_data = np.column_stack([hist_scaled, edges_flanked]) - - # set x and y vals - self._histogram_line.data[:, :2] = line_data - - bounds = (edges[0], edges[-1]) - limits = (edges_flanked[0], edges_flanked[-11]) - origin = (hist_scaled.max() / 2, 0) - - if reset_vmin_vmax: - # reset according to the new data - self._linear_region_selector.limits = limits - self._linear_region_selector.selection = bounds - else: - with pause_events(self._linear_region_selector, *self.images): - # don't change the current selection - self._linear_region_selector.limits = limits - - self._data = weakref.proxy(data) - - if self._colorbar is not None: - self._colorbar.clear_event_handlers() - self.world_object.remove(self._colorbar.world_object) - - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + if self._block_reentrance: + return - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + self._block_reentrance = True + try: + index_max = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (self._selector.selection[0], index_max) - # reset plotarea dims - self._plot_area.auto_scale() + self._colorbar.vmax = value - @property - def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic]: - return self._images + self._text_vmax.text = _format_value(value) + self._text_vmax.offset = (-0.45, self._selector.selection[1], 0) - @images.setter - def images(self, images): - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) + for image in self._images: + image.vmax = value - if self._images is not None: - for ig in self._images: - # cleanup events from current image graphics - ig_events = _get_image_graphic_events(ig) - ig.remove_event_handler(self._image_cmap_handler, *ig_events) + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False - self._images = images + def _selector_event_handler(self, ev: GraphicFeatureEvent): + """when the selector's selctor has changed, it will update the vmin, vmax, or both""" + selection = ev.info["value"] + index_min = int(selection[0]) + vmin = self._bin_centers_flanked[index_min] - ig_events = _get_image_graphic_events(self._images[0]) + index_max = int(selection[1]) + vmax = self._bin_centers_flanked[index_max] - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + match ev.info["change"]: + case "min": + self.vmin = vmin + case "max": + self.vmax = vmax + case _: + self.vmin, self.vmax = vmin, vmax def _open_cmap_picker(self, ev): + """open imgui cmap picker""" # check if right click if ev.button != 2: return @@ -431,7 +419,11 @@ def _open_cmap_picker(self, ev): self._plot_area.get_figure().open_popup("colormap-picker", pos, lut_tool=self) def _fpl_prepare_del(self): - self._linear_region_selector._fpl_prepare_del() - self._histogram_line._fpl_prepare_del() - del self._histogram_line - del self._linear_region_selector + """cleanup, need to disconnect events and remove image references for proper garbage collection""" + self._disconnect_images() + self._images.clear() + + for i in range(len(self._children)): + g = self._children.pop(0) + g._fpl_prepare_del() + del g diff --git a/fastplotlib/ui/_base.py b/fastplotlib/ui/_base.py index 3e763e08c..9767cf76f 100644 --- a/fastplotlib/ui/_base.py +++ b/fastplotlib/ui/_base.py @@ -123,8 +123,9 @@ def size(self) -> int | None: @size.setter def size(self, value): if not isinstance(value, int): - raise TypeError + raise TypeError(f"{self.__class__.__name__}.size must be an ") self._size = value + self._set_rect() @property def location(self) -> str: @@ -153,6 +154,7 @@ def height(self) -> int: def _set_rect(self, *args): self._x, self._y, self._width, self._height = self.get_rect() + self._figure._fpl_reset_layout() def get_rect(self) -> tuple[int, int, int, int]: """ diff --git a/fastplotlib/ui/right_click_menus/_colormap_picker.py b/fastplotlib/ui/right_click_menus/_colormap_picker.py index a80e5b2aa..9df26dcdc 100644 --- a/fastplotlib/ui/right_click_menus/_colormap_picker.py +++ b/fastplotlib/ui/right_click_menus/_colormap_picker.py @@ -154,7 +154,8 @@ def update(self): self._texture_height = (imgui.get_font_size()) - 2 if imgui.menu_item("Reset vmin-vmax", "", False)[0]: - self._lut_tool.images[0].reset_vmin_vmax() + for image in self._lut_tool.images: + image.reset_vmin_vmax() # add all the cmap options for cmap_type in COLORMAP_NAMES.keys(): diff --git a/fastplotlib/ui/right_click_menus/_standard_menu.py b/fastplotlib/ui/right_click_menus/_standard_menu.py index bb9e5bdef..33ab509d1 100644 --- a/fastplotlib/ui/right_click_menus/_standard_menu.py +++ b/fastplotlib/ui/right_click_menus/_standard_menu.py @@ -100,6 +100,12 @@ def update(self): ) self.get_subplot().camera.maintain_aspect = maintain_aspect + change, show_tooltips = imgui.menu_item( + "Show tooltips", "", self._figure.show_tooltips + ) + if change: + self._figure.show_tooltips = show_tooltips + imgui.separator() # toggles to flip axes cameras diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index dd527ca67..a513c791a 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,6 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * +from ._protocols import * @dataclass diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py new file mode 100644 index 000000000..7ae63ed67 --- /dev/null +++ b/fastplotlib/utils/_protocols.py @@ -0,0 +1,15 @@ +from typing import Protocol, runtime_checkable + + +ARRAY_LIKE_ATTRS = ["shape", "ndim", "__getitem__"] + + +@runtime_checkable +class ArrayProtocol(Protocol): + @property + def ndim(self) -> int: ... + + @property + def shape(self) -> tuple[int, ...]: ... + + def __getitem__(self, key): ... diff --git a/fastplotlib/widgets/image_widget/__init__.py b/fastplotlib/widgets/image_widget/__init__.py index 70a1aa8ae..dc5daea55 100644 --- a/fastplotlib/widgets/image_widget/__init__.py +++ b/fastplotlib/widgets/image_widget/__init__.py @@ -2,6 +2,7 @@ if IMGUI: from ._widget import ImageWidget + from ._processor import NDImageProcessor else: diff --git a/fastplotlib/widgets/image_widget/_processor.py b/fastplotlib/widgets/image_widget/_processor.py new file mode 100644 index 000000000..0dce84a5e --- /dev/null +++ b/fastplotlib/widgets/image_widget/_processor.py @@ -0,0 +1,519 @@ +import inspect +from typing import Literal, Callable +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDImageProcessor: + def __init__( + self, + data: ArrayLike | None, + n_display_dims: Literal[2, 3] = 2, + rgb: bool = False, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_sizes: tuple[int | None, ...] | int = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + ): + """ + An ND image that supports computing window functions, and functions over spatial dimensions. + + Parameters + ---------- + data: ArrayLike + array-like data, must have 2 or more dimensions + + n_display_dims: int, 2 or 3, default 2 + number of display dimensions + + rgb: bool, default False + whether the image data is RGB(A) or not + + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable, optional + A function or a ``tuple`` of functions that are applied to a rolling window of the data. + + You can provide unique window functions for each dimension. If you want to apply a window function + only to a subset of the dimensions, put ``None`` to indicate no window function for a given dimension. + + A "window function" must take ``axis`` argument, which is an ``int`` that specifies the axis along which + the window function is applied. It must also take a ``keepdims`` argument which is a ``bool``. The window + function **must** return an array that has the same number of dimensions as the original ``data`` array, + therefore the size of the dimension along which the window was applied will reduce to ``1``. + + The output array-like type from a window function **must** support a ``.squeeze()`` method, but the + function itself should NOT squeeze the output array. + + window_sizes: tuple[int | None, ...], optional + ``tuple`` of ``int`` that specifies the window size for each dimension. + + window_order: tuple[int, ...] | None, optional + order in which to apply the window functions, by default just applies it from the left-most dim to the + right-most slider dim. + + spatial_func: Callable[[ArrayLike], ArrayLike] | None, optional + A function that is applied on the _spatial_ dimensions of the data array, i.e. the last 2 or 3 dimensions. + This function is applied after the window functions (if present). + + compute_histogram: bool, default True + Compute a histogram of the data, auto re-computes if window function propties or spatial_func changes. + Disable if slow. + + """ + # set as False until data, window funcs stuff and spatial func is all set + self._compute_histogram = False + + self.data = data + self.n_display_dims = n_display_dims + self.rgb = rgb + + self.window_funcs = window_funcs + self.window_sizes = window_sizes + self.window_order = window_order + + self._spatial_func = spatial_func + + self._compute_histogram = compute_histogram + self._recompute_histogram() + + @property + def data(self) -> ArrayLike | None: + """get or set the data array""" + return self._data + + @data.setter + def data(self, data: ArrayLike): + # check that all array-like attributes are present + if data is None: + self._data = None + return + + if not isinstance(data, ArrayProtocol): + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 2: + raise IndexError( + f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" + ) + + self._data = data + self._recompute_histogram() + + @property + def ndim(self) -> int: + if self.data is None: + return 0 + + return self.data.ndim + + @property + def shape(self) -> tuple[int, ...]: + if self._data is None: + return tuple() + + return self.data.shape + + @property + def rgb(self) -> bool: + """whether or not the data is rgb(a)""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: bool): + if not isinstance(rgb, bool): + raise TypeError + + if rgb and self.ndim < 3: + raise IndexError( + f"require 3 or more dims for RGB, you have: {self.ndim} dims" + ) + + self._rgb = rgb + + @property + def n_slider_dims(self) -> int: + """number of slider dimensions""" + if self._data is None: + return 0 + + return self.ndim - self.n_display_dims - int(self.rgb) + + @property + def slider_dims(self) -> tuple[int, ...] | None: + """tuple indicating the slider dimension indices""" + if self.n_slider_dims == 0: + return None + + return tuple(range(self.n_slider_dims)) + + @property + def slider_dims_shape(self) -> tuple[int, ...] | None: + if self.n_slider_dims == 0: + return None + + return tuple(self.shape[i] for i in self.slider_dims) + + @property + def n_display_dims(self) -> Literal[2, 3]: + """get or set the number of display dimensions, `2` for 2D image and `3` for volume images""" + return self._n_display_dims + + # TODO: make n_display_dims settable, requires thinking about inserting and poping indices in ImageWidget + @n_display_dims.setter + def n_display_dims(self, n: Literal[2, 3]): + if not (n == 2 or n == 3): + raise ValueError( + f"`n_display_dims` must be an with a value of 2 or 3, you have passed: {n}" + ) + self._n_display_dims = n + self._recompute_histogram() + + @property + def max_n_display_dims(self) -> int: + """maximum number of possible display dims""" + # min 2, max 3, accounts for if data is None and ndim is 0 + return max(2, min(3, self.ndim - int(self.rgb))) + + @property + def display_dims(self) -> tuple[int, int] | tuple[int, int, int]: + """tuple indicating the display dimension indices""" + return tuple(range(self.data.ndim))[self.n_slider_dims :] + + @property + def window_funcs( + self, + ) -> tuple[WindowFuncCallable | None, ...] | None: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + ): + if window_funcs is None: + self._window_funcs = None + return + + if callable(window_funcs): + window_funcs = (window_funcs,) + + # if all are None + if all([f is None for f in window_funcs]): + self._window_funcs = None + return + + self._validate_window_func(window_funcs) + + self._window_funcs = tuple(window_funcs) + self._recompute_histogram() + + def _validate_window_func(self, funcs): + if isinstance(funcs, (tuple, list)): + for f in funcs: + if f is None: + pass + elif callable(f): + sig = inspect.signature(f) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {f} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" + ) + + if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " + f"and you passed {len(funcs)} `window_funcs`: {funcs}" + ) + + @property + def window_sizes(self) -> tuple[int | None, ...] | None: + """get or set window sizes used for the corresponding window functions, see docstring for details""" + return self._window_sizes + + @window_sizes.setter + def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): + if window_sizes is None: + self._window_sizes = None + return + + if isinstance(window_sizes, int): + window_sizes = (window_sizes,) + + # if all are None + if all([w is None for w in window_sizes]): + self._window_sizes = None + return + + if not all([isinstance(w, (int)) or w is None for w in window_sizes]): + raise TypeError( + f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" + ) + + if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_sizes` must be the same as the number of slider dims, " + f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " + f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" + ) + + # make all window sizes are valid numbers + _window_sizes = list() + for i, w in enumerate(window_sizes): + if w is None: + _window_sizes.append(None) + continue + + if w < 0: + raise ValueError( + f"negative window size passed, all `window_sizes` must be positive " + f"integers or `None`, you passed: {_window_sizes}" + ) + + if w == 0 or w == 1: + # this is not a real window, set as None + w = None + + elif w % 2 == 0: + # odd window sizes makes most sense + warn( + f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + ) + w += 1 + + _window_sizes.append(w) + + self._window_sizes = tuple(_window_sizes) + self._recompute_histogram() + + @property + def window_order(self) -> tuple[int, ...] | None: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[int] | None): + if order is None: + self._window_order = None + return + + if order is not None: + if not all([d <= self.n_slider_dims for d in order]): + raise IndexError( + f"all `window_order` entries must be <= n_slider_dims\n" + f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" + ) + + if not all([d >= 0 for d in order]): + raise IndexError( + f"all `window_order` entires must be >= 0, you have passed: {order}" + ) + + self._window_order = tuple(order) + self._recompute_histogram() + + @property + def spatial_func(self) -> Callable[[ArrayLike], ArrayLike] | None: + """get or set a spatial_func function, see docstring for details""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, func: Callable[[ArrayLike], ArrayLike] | None): + if not (callable(func) or func is not None): + raise TypeError( + f"`spatial_func` must be a callable or `None`, you have passed: {func}" + ) + + self._spatial_func = func + self._recompute_histogram() + + @property + def compute_histogram(self) -> bool: + return self._compute_histogram + + @compute_histogram.setter + def compute_histogram(self, compute: bool): + if compute: + if self._compute_histogram is False: + # compute a histogram + self._recompute_histogram() + self._compute_histogram = True + else: + self._compute_histogram = False + self._histogram = None + + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: + """ + an estimate of the histogram of the data, (histogram_values, bin_edges). + + returns `None` if `compute_histogram` is `False` + """ + return self._histogram + + def _apply_window_function(self, indices: tuple[int, ...]) -> ArrayLike: + """applies the window functions for each dimension specified""" + # window size for each dim + winds = self._window_sizes + # window function for each dim + funcs = self._window_funcs + + if winds is None or funcs is None: + # no window funcs or window sizes, just slice data and return + # clamp to max bounds + indexer = list() + for dim, i in enumerate(indices): + i = min(self.shape[dim] - 1, i) + indexer.append(i) + + return self.data[tuple(indexer)] + + # order in which window funcs are applied + order = self._window_order + + if order is not None: + # remove any entries in `window_order` where the specified dim + # has a window function or window size specified as `None` + # example: + # window_sizes = (3, 2) + # window_funcs = (np.mean, None) + # order = (0, 1) + # `1` is removed from the order since that window_func is `None` + order = tuple( + d for d in order if winds[d] is not None and funcs[d] is not None + ) + else: + # sequential order + order = list() + for d in range(self.n_slider_dims): + if winds[d] is not None and funcs[d] is not None: + order.append(d) + + # the final indexer which will be used on the data array + indexer = list() + + for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): + # clamp i within the max bounds + i = min(self.shape[dim_index] - 1, i) + + if (w is not None) and (f is not None): + # specify slice window if both window size and function for this dim are not None + hw = int((w - 1) / 2) # half window + + # start index cannot be less than 0 + start = max(0, i - hw) + + # stop index cannot exceed the bounds of this dimension + stop = min(self.shape[dim_index] - 1, i + hw) + + s = slice(start, stop, 1) + else: + s = slice(i, i + 1, 1) + + indexer.append(s) + + # apply indexer to slice data with the specified windows + data_sliced = self.data[tuple(indexer)] + + # finally apply the window functions in the specified order + for dim in order: + f = funcs[dim] + + data_sliced = f(data_sliced, axis=dim, keepdims=True) + + return data_sliced + + def get(self, indices: tuple[int, ...]) -> ArrayLike | None: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + if self.data is None: + return None + + if self.n_slider_dims != 0: + if len(indices) != self.n_slider_dims: + raise IndexError( + f"Must specify index for every slider dim, you have specified an index: {indices}\n" + f"But there are: {self.n_slider_dims} slider dims." + ) + # get output after processing through all window funcs + # squeeze to remove all dims of size 1 + window_output = self._apply_window_function(indices).squeeze() + else: + # data is a static image or volume + window_output = self.data + + # apply spatial_func + if self.spatial_func is not None: + final_output = self.spatial_func(window_output) + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `spatial_func` must match the number of display dims." + f"Output after `spatial_func` returned an array with {final_output.ndim} dims and " + f"of shape: {final_output.shape}, expected {self.n_display_dims} dims" + ) + else: + # check that output ndim after window functions matches display dims + final_output = window_output + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `window_funcs` must match the number of display dims." + f"Output after `window_funcs` returned an array with {window_output.ndim} dims and " + f"of shape: {window_output.shape}{' with rgb(a) channels' if self.rgb else ''}, " + f"expected {self.n_display_dims} dims" + ) + + return final_output + + def _recompute_histogram(self): + """ + + Returns + ------- + (histogram_values, bin_edges) + + """ + if not self._compute_histogram or self.data is None: + self._histogram = None + return + + if self.spatial_func is not None: + # don't subsample spatial dims if a spatial function is used + # spatial functions often operate on the spatial dims, ex: a gaussian kernel + # so their results require the full spatial resolution, the histogram of a + # spatially subsampled image will be very different + ignore_dims = self.display_dims + else: + ignore_dims = None + + sub = subsample_array(self.data, ignore_dims=ignore_dims) + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] + + self._histogram = np.histogram(sub_real, bins=100) diff --git a/fastplotlib/widgets/image_widget/_properties.py b/fastplotlib/widgets/image_widget/_properties.py new file mode 100644 index 000000000..060314439 --- /dev/null +++ b/fastplotlib/widgets/image_widget/_properties.py @@ -0,0 +1,139 @@ +from pprint import pformat +from typing import Iterable + +import numpy as np + +from ._processor import NDImageProcessor + + +class ImageWidgetProperty: + __class_getitem__ = classmethod(type(list[int])) + + def __init__( + self, + image_widget, + attribute: str, + ): + self._image_widget = image_widget + self._image_processors: list[NDImageProcessor] = image_widget._image_processors + self._attribute = attribute + + def _get_key(self, key: slice | int | np.integer | str) -> int | slice: + if not isinstance(key, (slice | int, np.integer, str)): + raise TypeError( + f"can index `{self._attribute}` only with a , , or a indicating the subplot name." + f"You tried to index with: {key}" + ) + + if isinstance(key, str): + for i, subplot in enumerate(self._image_widget.figure): + if subplot.name == key: + key = i + break + else: + raise IndexError(f"No subplot with given name: {key}") + + return key + + def __getitem__(self, key): + key = self._get_key(key) + # return image processor attribute at this index + if isinstance(key, (int, np.integer)): + return getattr(self._image_processors[key], self._attribute) + + # if it's a slice + processors = self._image_processors[key] + + return tuple(getattr(p, self._attribute) for p in processors) + + def __setitem__(self, key, value): + key = self._get_key(key) + + # get the values from the ImageWidget property + new_values = list(getattr(p, self._attribute) for p in self._image_processors) + + # set the new value at this slice + new_values[key] = value + + # call the setter + setattr(self._image_widget, self._attribute, new_values) + + def __iter__(self): + for image_processor in self._image_processors: + yield getattr(image_processor, self._attribute) + + def __repr__(self): + return f"{self._attribute}: {pformat(self[:])}" + + def __eq__(self, other): + return self[:] == other + + +class Indices: + def __init__( + self, + indices: list[int], + image_widget, + ): + self._data = indices + + self._image_widget = image_widget + + def __iter__(self): + for i in self._data: + yield i + + def _parse_key(self, key: int | np.integer | str) -> int: + if not isinstance(key, (int, np.integer, str)): + raise TypeError( + f"indices can only be indexed with or types, you have used: {key}" + ) + + if isinstance(key, str): + # get integer index from user's names + names = self._image_widget._slider_dim_names + if key not in names: + raise KeyError( + f"dim with name: {key} not found in slider_dim_names, current names are: {names}" + ) + + key = names.index(key) + + return key + + def __getitem__(self, key: int | np.integer | str) -> int | tuple[int]: + if isinstance(key, str): + key = self._parse_key(key) + + return self._data[key] + + def __setitem__(self, key, value): + key = self._parse_key(key) + + if not isinstance(value, (int, np.integer)): + raise TypeError( + f"indices values can only be set with integers, you have tried to set the value: {value}" + ) + + new_indices = list(self._data) + new_indices[key] = value + + self._image_widget.indices = new_indices + + def _fpl_set(self, values): + self._data[:] = values + + def pop_dim(self): + self._data.pop(0) + + def push_dim(self): + self._data.insert(0, 0) + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + return self._data == other + + def __repr__(self): + return f"indices: {self._data}" diff --git a/fastplotlib/widgets/image_widget/_sliders.py b/fastplotlib/widgets/image_widget/_sliders.py index 393b13273..1945b8cfb 100644 --- a/fastplotlib/widgets/image_widget/_sliders.py +++ b/fastplotlib/widgets/image_widget/_sliders.py @@ -11,50 +11,66 @@ def __init__(self, figure, size, location, title, image_widget): super().__init__(figure=figure, size=size, location=location, title=title) self._image_widget = image_widget + n_sliders = self._image_widget.n_sliders + # whether or not a dimension is in play mode - self._playing: dict[str, bool] = {"t": False, "z": False} + self._playing: list[bool] = [False] * n_sliders # approximate framerate for playing - self._fps: dict[str, int] = {"t": 20, "z": 20} + self._fps: list[int] = [20] * n_sliders + # framerate converted to frame time - self._frame_time: dict[str, float] = {"t": 1 / 20, "z": 1 / 20} + self._frame_time: list[float] = [1 / 20] * n_sliders # last timepoint that a frame was displayed from a given dimension - self._last_frame_time: dict[str, float] = {"t": 0, "z": 0} + self._last_frame_time: list[float] = [perf_counter()] * n_sliders + # loop playback self._loop = False - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": - self._playing["t"] = True + # auto-plays the ImageWidget's left-most dimension in docs galleries + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": + self._playing[0] = True self._loop = True - def set_index(self, dim: str, index: int): - """set the current_index of the ImageWidget""" + self.pause = False + + def pop_dim(self): + """pop right most dim""" + i = 0 # len(self._image_widget.indices) - 1 + for l in [self._playing, self._fps, self._frame_time, self._last_frame_time]: + l.pop(i) + + def push_dim(self): + """push a new dim""" + self._playing.insert(0, False) + self._fps.insert(0, 20) + self._frame_time.insert(0, 1 / 20) + self._last_frame_time.insert(0, perf_counter()) + + def set_index(self, dim: int, new_index: int): + """set the index of the ImageWidget""" # make sure the max index for this dim is not exceeded - max_index = self._image_widget._dims_max_bounds[dim] - 1 - if index > max_index: + max_index = self._image_widget.bounds[dim] - 1 + if new_index > max_index: if self._loop: # loop back to index zero if looping is enabled - index = 0 + new_index = 0 else: # if looping not enabled, stop playing this dimension self._playing[dim] = False return - # set current_index - self._image_widget.current_index = {dim: min(index, max_index)} + # set new index + new_indices = list(self._image_widget.indices) + new_indices[dim] = new_index + self._image_widget.indices = new_indices def update(self): """called on every render cycle to update the GUI elements""" - # store the new index of the image widget ("t" and "z") - new_index = dict() - - # flag if the index changed - flag_index_changed = False - # reset vmin-vmax using full orig data if imgui.button(label=fa.ICON_FA_CIRCLE_HALF_STROKE + fa.ICON_FA_FILM): self._image_widget.reset_vmin_vmax() @@ -72,7 +88,7 @@ def update(self): now = perf_counter() # buttons and slider UI elements for each dim - for dim in self._image_widget.slider_dims: + for dim in range(self._image_widget.n_sliders): imgui.push_id(f"{self._id_counter}_{dim}") if self._playing[dim]: @@ -83,7 +99,7 @@ def update(self): # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index if now - self._last_frame_time[dim] >= self._frame_time[dim]: - self.set_index(dim, self._image_widget.current_index[dim] + 1) + self.set_index(dim, self._image_widget.indices[dim] + 1) self._last_frame_time[dim] = now else: @@ -97,12 +113,12 @@ def update(self): imgui.same_line() # step back one frame button if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.current_index[dim] - 1) + self.set_index(dim, self._image_widget.indices[dim] - 1) imgui.same_line() # step forward one frame button if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.current_index[dim] + 1) + self.set_index(dim, self._image_widget.indices[dim] + 1) imgui.same_line() # stop button @@ -137,10 +153,15 @@ def update(self): self._fps[dim] = value self._frame_time[dim] = 1 / value - val = self._image_widget.current_index[dim] - vmax = self._image_widget._dims_max_bounds[dim] - 1 + val = self._image_widget.indices[dim] + vmax = self._image_widget.bounds[dim] - 1 + + dim_name = dim + if self._image_widget._slider_dim_names is not None: + if dim < len(self._image_widget._slider_dim_names): + dim_name = self._image_widget._slider_dim_names[dim] - imgui.text(f"{dim}: ") + imgui.text(f"dim '{dim_name}:' ") imgui.same_line() # so that slider occupies full width imgui.set_next_item_width(self.width * 0.85) @@ -154,18 +175,12 @@ def update(self): # slider for this dimension changed, index = imgui.slider_int( - f"{dim}", v=val, v_min=0, v_max=vmax, flags=flags + f"d: {dim}", v=val, v_min=0, v_max=vmax, flags=flags ) - new_index[dim] = index - - # if the slider value changed for this dimension - flag_index_changed |= changed + if changed: + new_indices = list(self._image_widget.indices) + new_indices[dim] = index + self._image_widget.indices = new_indices imgui.pop_id() - - if flag_index_changed: - # if any slider dim changed set the new index of the image widget - self._image_widget.current_index = new_index - - self.size = int(imgui.get_window_height()) diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py index 715fe3489..7db265c0c 100644 --- a/fastplotlib/widgets/image_widget/_widget.py +++ b/fastplotlib/widgets/image_widget/_widget.py @@ -1,5 +1,4 @@ -from copy import deepcopy -from typing import Callable +from typing import Callable, Sequence, Literal from warnings import warn import numpy as np @@ -7,297 +6,50 @@ from rendercanvas import BaseRenderCanvas from ...layouts import ImguiFigure as Figure -from ...graphics import ImageGraphic -from ...utils import calculate_figure_shape, quick_min_max +from ...graphics import ImageGraphic, ImageVolumeGraphic +from ...utils import calculate_figure_shape, quick_min_max, ArrayProtocol from ...tools import HistogramLUTTool from ._sliders import ImageWidgetSliders +from ._processor import NDImageProcessor, WindowFuncCallable +from ._properties import ImageWidgetProperty, Indices -# Number of dimensions that represent one image/one frame -# For grayscale shape will be [n_rows, n_cols], i.e. 2 dims -# For RGB(A) shape will be [n_rows, n_cols, c] where c is of size 3 (RGB) or 4 (RGBA) -IMAGE_DIM_COUNTS = {"gray": 2, "rgb": 3} - -# Map boolean (indicating whether we use RGB or grayscale) to the string. Used to index RGB_DIM_MAP -RGB_BOOL_MAP = {False: "gray", True: "rgb"} - -# Dimensions that can be scrolled from a given data array -SCROLLABLE_DIMS_ORDER = { - 0: "", - 1: "t", - 2: "tz", -} - -ALLOWED_SLIDER_DIMS = {0: "t", 1: "z"} - -ALLOWED_WINDOW_DIMS = {"t", "z"} - - -def _is_arraylike(obj) -> bool: - """ - Checks if the object is array-like. - For now just checks if obj has `__getitem__()` - """ - for attr in ["__getitem__", "shape", "ndim"]: - if not hasattr(obj, attr): - return False - - return True - - -class _WindowFunctions: - """Stores window function and window size""" - - def __init__(self, image_widget, func: callable, window_size: int): - self._image_widget = image_widget - self._func = None - self.func = func - - self._window_size = 0 - self.window_size = window_size - - @property - def func(self) -> callable: - """Get or set the function""" - return self._func - - @func.setter - def func(self, func: callable): - self._func = func - - # force update - self._image_widget.current_index = self._image_widget.current_index - - @property - def window_size(self) -> int: - """Get or set window size""" - return self._window_size - - @window_size.setter - def window_size(self, ws: int): - if ws is None: - self._window_size = None - return - - if not isinstance(ws, int): - raise TypeError("window size must be an int") - - if ws < 3: - warn( - f"Invalid 'window size' value for function: {self.func}, " - f"setting 'window size' = None for this function. " - f"Valid values are integers >= 3." - ) - self.window_size = None - return - - if ws % 2 == 0: - ws += 1 - - self._window_size = ws - - self._image_widget.current_index = self._image_widget.current_index - - def __repr__(self): - return f"func: {self.func}, window_size: {self.window_size}" +IMGUI_SLIDER_HEIGHT = 49 class ImageWidget: - @property - def figure(self) -> Figure: - """ - ``Figure`` used by `ImageWidget`. - """ - return self._figure - - @property - def managed_graphics(self) -> list[ImageGraphic]: - """List of ``ImageWidget`` managed graphics.""" - iw_managed = list() - for subplot in self.figure: - # empty subplots will not have any image widget data - if len(subplot.graphics) > 0: - iw_managed.append(subplot["image_widget_managed"]) - return iw_managed - - @property - def cmap(self) -> list[str]: - cmaps = list() - for g in self.managed_graphics: - cmaps.append(g.cmap) - - return cmaps - - @cmap.setter - def cmap(self, names: str | list[str]): - if isinstance(names, list): - if not all([isinstance(n, str) for n in names]): - raise TypeError( - f"Must pass cmap name as a `str` of list of `str`, you have passed:\n{names}" - ) - - if not len(names) == len(self.managed_graphics): - raise IndexError( - f"If passing a list of cmap names, the length of the list must be the same as the number of " - f"image widget subplots. You have passed: {len(names)} cmap names and have " - f"{len(self.managed_graphics)} image widget subplots" - ) - - for name, g in zip(names, self.managed_graphics): - g.cmap = name - - elif isinstance(names, str): - for g in self.managed_graphics: - g.cmap = names - - @property - def data(self) -> list[np.ndarray]: - """data currently displayed in the widget""" - return self._data - - @property - def ndim(self) -> int: - """Number of dimensions of grayscale data displayed in the widget (it will be 1 more for RGB(A) data)""" - return self._ndim - - @property - def n_scrollable_dims(self) -> list[int]: - """ - list indicating the number of dimenensions that are scrollable for each data array - All other dimensions are frame/image data, i.e. [rows, cols] or [rows, cols, rgb(a)] - """ - return self._n_scrollable_dims - - @property - def slider_dims(self) -> list[str]: - """the dimensions that the sliders index""" - return self._slider_dims - - @property - def current_index(self) -> dict[str, int]: - """ - Get or set the current index - - Returns - ------- - index: Dict[str, int] - | ``dict`` for indexing each dimension, provide a ``dict`` with indices for all dimensions used by sliders - or only a subset of dimensions used by the sliders. - | example: if you have sliders for dims "t" and "z", you can pass either ``{"t": 10}`` to index to position - 10 on dimension "t" or ``{"t": 5, "z": 20}`` to index to position 5 on dimension "t" and position 20 on - dimension "z" simultaneously. - - """ - return self._current_index - - @current_index.setter - def current_index(self, index: dict[str, int]): - if not self._initialized: - return - - if self._reentrant_block: - return - - try: - self._reentrant_block = True # block re-execution until current_index has *fully* completed execution - if not set(index.keys()).issubset(set(self._current_index.keys())): - raise KeyError( - f"All dimension keys for setting `current_index` must be present in the widget sliders. " - f"The dimensions currently used for sliders are: {list(self.current_index.keys())}" - ) - - for k, val in index.items(): - if not isinstance(val, int): - raise TypeError("Indices for all dimensions must be int") - if val < 0: - raise IndexError( - "negative indexing is not supported for ImageWidget" - ) - if val > self._dims_max_bounds[k]: - raise IndexError( - f"index {val} is out of bounds for dimension '{k}' " - f"which has a max bound of: {self._dims_max_bounds[k]}" - ) - - self._current_index.update(index) - - for i, (ig, data) in enumerate(zip(self.managed_graphics, self.data)): - frame = self._process_indices(data, self._current_index) - frame = self._process_frame_apply(frame, i) - ig.data = frame - - # call any event handlers - for handler in self._current_index_changed_handlers: - handler(self.current_index) - except Exception as exc: - # raise original exception - raise exc # current_index setter has raised. The lines above below are probably more relevant! - finally: - # set_value has finished executing, now allow future executions - self._reentrant_block = False - - @property - def n_img_dims(self) -> list[int]: - """ - list indicating the number of dimensions that contain image/single frame data for each data array. - if 2: data are grayscale, i.e. [x, y] dims, if 3: data are [x, y, c] where c is RGB or RGBA, - this is the complement of `n_scrollable_dims` - """ - return self._n_img_dims - - def _get_n_scrollable_dims(self, curr_arr: np.ndarray, rgb: bool) -> list[int]: - """ - For a given ``array`` displayed in the ImageWidget, this function infers how many of the dimensions are - supported by sliders (aka scrollable). Ex: "xy" data has 0 scrollable dims, "txy" has 1, "tzxy" has 2. - - Parameters - ---------- - curr_arr: np.ndarray - np.ndarray or a list of array-like - - rgb: bool - True if we view this as RGB(A) and False if grayscale - - Returns - ------- - int - Number of scrollable dimensions for each ``array`` in the dataset. - """ - - n_img_dims = IMAGE_DIM_COUNTS[RGB_BOOL_MAP[rgb]] - # Make sure each image stack at least ``n_img_dims`` dimensions - if len(curr_arr.shape) < n_img_dims: - raise ValueError( - f"Your array has shape {curr_arr.shape} " - f"but you specified that each image in your array is {n_img_dims}D " - ) - - # If RGB(A), last dim must be 3 or 4 - if n_img_dims == 3: - if not (curr_arr.shape[-1] == 3 or curr_arr.shape[-1] == 4): - raise ValueError( - f"Expected size 3 or 4 for last dimension of RGB(A) array, got: {curr_arr.shape[-1]}." - ) - - n_scrollable_dims = len(curr_arr.shape) - n_img_dims - - if n_scrollable_dims not in SCROLLABLE_DIMS_ORDER.keys(): - raise ValueError(f"Array had shape {curr_arr.shape} which is not supported") - - return n_scrollable_dims - def __init__( self, - data: np.ndarray | list[np.ndarray], - window_funcs: dict[str, tuple[Callable, int]] = None, - frame_apply: Callable | dict[int, Callable] = None, + data: ArrayProtocol | Sequence[ArrayProtocol | None] | None, + processors: NDImageProcessor | Sequence[NDImageProcessor] = NDImageProcessor, + n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]] = 2, + slider_dim_names: Sequence[str] | None = None, # dim names left -> right + rgb: bool | Sequence[bool] = False, + cmap: str | Sequence[str] = "plasma", + window_funcs: ( + tuple[WindowFuncCallable | None, ...] + | WindowFuncCallable + | None + | Sequence[ + tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None + ] + ) = None, + window_sizes: ( + tuple[int | None, ...] | Sequence[tuple[int | None, ...] | None] + ) = None, + window_order: tuple[int, ...] | Sequence[tuple[int, ...] | None] = None, + spatial_func: ( + Callable[[ArrayProtocol], ArrayProtocol] + | Sequence[Callable[[ArrayProtocol], ArrayProtocol]] + | None + ) = None, + sliders_dim_order: Literal["right", "left"] = "right", figure_shape: tuple[int, int] = None, - names: list[str] = None, + names: Sequence[str] = None, figure_kwargs: dict = None, histogram_widget: bool = True, - rgb: bool | list[bool] = None, - cmap: str = "plasma", - graphic_kwargs: dict = None, + histogram_init_quantile: int = (0, 100), + graphic_kwargs: dict | Sequence[dict] = None, ): """ This widget facilitates high-level navigation through image stacks, which are arrays containing one or more @@ -307,36 +59,22 @@ def __init__( Allowed dimensions orders for each image stack: Note that each has a an optional (c) channel which refers to RGB(A) a channel. So this channel should be either 3 or 4. - ======= ========== - n_dims dims order - ======= ========== - 2 "xy(c)" - 3 "txy(c)" - 4 "tzxy(c)" - ======= ========== - Parameters ---------- - data: Union[np.ndarray, List[np.ndarray] - array-like or a list of array-like - - window_funcs: dict[str, tuple[Callable, int]], i.e. {"t" or "z": (callable, int)} - | Apply function(s) with rolling windows along "t" and/or "z" dimensions of the `data` arrays. - | Pass a dict in the form: {dimension: (func, window_size)}, `func` must take a slice of the data array as - | the first argument and must take `axis` as a kwarg. - | Ex: mean along "t" dimension: {"t": (np.mean, 11)}, if `current_index` of "t" is 50, it will pass frames - | 45 to 55 to `np.mean` with `axis=0`. - | Ex: max along z dim: {"z": (np.max, 3)}, passes current, previous & next frame to `np.max` with `axis=1` - - frame_apply: Union[callable, Dict[int, callable]] - | Apply function(s) to `data` arrays before to generate final 2D image that is displayed. - | Ex: apply a spatial gaussian filter - | Pass a single function or a dict of functions to apply to each array individually - | examples: ``{array_index: to_grayscale}``, ``{0: to_grayscale, 2: threshold_img}`` - | "array_index" is the position of the corresponding array in the data list. - | if `window_funcs` is used, then this function is applied after `window_funcs` - | this function must be a callable that returns a 2D array - | example use case: converting an RGB frame from video to a 2D grayscale frame + data: ArrayProtocol | Sequence[ArrayProtocol | None] | None + array-like or a list of array-like, each array must have a minimum of 2 dimensions + + processors: NDImageProcessor | Sequence[NDImageProcessor], default NDImageProcessor + The image processors used for each n-dimensional data array + + n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]], default 2 + number of display dimensions + + slider_dim_names: Sequence[str], optional + optional list/tuple of names for each slider dim + + rgb: bool | Sequence[bool], default + whether or not each data array represents RGB(A) images figure_shape: Optional[Tuple[int, int]] manually provide the shape for the Figure, otherwise the number of rows and columns is estimated @@ -358,155 +96,221 @@ def __init__( passed to each ImageGraphic in the ImageWidget figure subplots """ - self._initialized = False if figure_kwargs is None: figure_kwargs = dict() - if _is_arraylike(data): + if isinstance(data, ArrayProtocol) or (data is None): data = [data] - if isinstance(data, list): + elif isinstance(data, (list, tuple)): # verify that it's a list of np.ndarray - if all([_is_arraylike(d) for d in data]): - # Grid computations - if figure_shape is None: - if "shape" in figure_kwargs: - figure_shape = figure_kwargs["shape"] - else: - figure_shape = calculate_figure_shape(len(data)) - - # Regardless of how figure_shape is computed, below code - # verifies that figure shape is large enough for the number of image arrays passed - if figure_shape[0] * figure_shape[1] < len(data): - original_shape = (figure_shape[0], figure_shape[1]) - figure_shape = calculate_figure_shape(len(data)) - warn( - f"Original `figure_shape` was: {original_shape} " - f" but data length is {len(data)}" - f" Resetting figure shape to: {figure_shape}" - ) - - self._data: list[np.ndarray] = data - - # Establish number of image dimensions and number of scrollable dimensions for each array - if rgb is None: - rgb = [False] * len(self.data) - if isinstance(rgb, bool): - rgb = [rgb] * len(self.data) - if not isinstance(rgb, list): - raise TypeError( - f"`rgb` parameter must be a bool or list of bool, a <{type(rgb)}> was provided" - ) - if not len(rgb) == len(self.data): - raise ValueError( - f"len(rgb) != len(data), {len(rgb)} != {len(self.data)}. These must be equal" - ) - - self._rgb = rgb - - self._n_img_dims = [ - IMAGE_DIM_COUNTS[RGB_BOOL_MAP[self._rgb[i]]] - for i in range(len(self.data)) - ] - - self._n_scrollable_dims = [ - self._get_n_scrollable_dims(self.data[i], self._rgb[i]) - for i in range(len(self.data)) - ] - - # Define ndim of ImageWidget instance as largest number of scrollable dims + 2 (grayscale dimensions) - self._ndim = ( - max( - [ - self.n_scrollable_dims[i] - for i in range(len(self.n_scrollable_dims)) - ] - ) - + IMAGE_DIM_COUNTS[RGB_BOOL_MAP[False]] + if not all([isinstance(d, ArrayProtocol) or d is None for d in data]): + raise TypeError( + f"`data` must be an array-like type or a list/tuple of array-like or None. " + f"You have passed the following type {type(data)}" ) - if names is not None: - if not all([isinstance(n, str) for n in names]): - raise TypeError( - "optional argument `names` must be a list of str" - ) + else: + raise TypeError( + f"`data` must be an array-like type or a list/tuple of array-like or None. " + f"You have passed the following type {type(data)}" + ) - if len(names) != len(self.data): - raise ValueError( - "number of `names` for subplots must be same as the number of data arrays" - ) + if issubclass(processors, NDImageProcessor): + processors = [processors] * len(data) - else: + elif isinstance(processors, (tuple, list)): + if not all([issubclass(p, NDImageProcessor) for p in processors]): raise TypeError( - f"If passing a list to `data` all elements must be an " - f"array-like type representing an n-dimensional image. " - f"You have passed the following types:\n" - f"{[type(a) for a in data]}" + f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " + f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" ) + else: raise TypeError( - f"`data` must be an array-like type or a list of array-like." - f"You have passed the following type {type(data)}" + f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " + f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" ) - # Sliders are made for all dimensions except the image dimensions - self._slider_dims = list() - max_scrollable = max( - [self.n_scrollable_dims[i] for i in range(len(self.n_scrollable_dims))] - ) - for dim in range(max_scrollable): - if dim in ALLOWED_SLIDER_DIMS.keys(): - self.slider_dims.append(ALLOWED_SLIDER_DIMS[dim]) + # subplot layout + if figure_shape is None: + if "shape" in figure_kwargs: + figure_shape = figure_kwargs["shape"] + else: + figure_shape = calculate_figure_shape(len(data)) - self._frame_apply: dict[int, callable] = dict() + # Regardless of how figure_shape is computed, below code + # verifies that figure shape is large enough for the number of image arrays passed + if figure_shape[0] * figure_shape[1] < len(data): + original_shape = (figure_shape[0], figure_shape[1]) + figure_shape = calculate_figure_shape(len(data)) + warn( + f"Original `figure_shape` was: {original_shape} " + f" but data length is {len(data)}" + f" Resetting figure shape to: {figure_shape}" + ) - if frame_apply is not None: - if callable(frame_apply): - self._frame_apply = frame_apply + elif isinstance(rgb, bool): + rgb = [rgb] * len(data) - elif isinstance(frame_apply, dict): - self._frame_apply: dict[int, callable] = dict.fromkeys( - list(range(len(self.data))) - ) + if not all([isinstance(v, bool) for v in rgb]): + raise TypeError( + f"`rgb` parameter must be a bool or a Sequence of bool, you have passed: {rgb}" + ) - # dict of {array: dims_order_str} - for data_ix in list(frame_apply.keys()): - if not isinstance(data_ix, int): - raise TypeError("`frame_apply` dict keys must be ") - try: - self._frame_apply[data_ix] = frame_apply[data_ix] - except Exception: - raise IndexError( - f"key index {data_ix} out of bounds for `frame_apply`, the bounds are 0 - {len(self.data)}" - ) - else: - raise TypeError( - f"`frame_apply` must be a callable or , " - f"you have passed a: <{type(frame_apply)}>" + if not len(rgb) == len(data): + raise ValueError( + f"len(rgb) != len(data), {len(rgb)} != {len(data)}. These must be equal" + ) + + if names is not None: + if not all([isinstance(n, str) for n in names]): + raise TypeError("optional argument `names` must be a Sequence of str") + + if len(names) != len(data): + raise ValueError( + "number of `names` for subplots must be same as the number of data arrays" ) - # current_index stores {dimension_index: slice_index} for every dimension - self._current_index: dict[str, int] = {sax: 0 for sax in self.slider_dims} + # verify window funcs + if window_funcs is None: + win_funcs = [None] * len(data) + + elif callable(window_funcs) or all( + [callable(f) or f is None for f in window_funcs] + ): + # across all data arrays + # one window function defined for all dims, or window functions defined per-dim + win_funcs = [window_funcs] * len(data) - self._window_funcs = None - self.window_funcs = window_funcs + # if the above two clauses didn't trigger, then window_funcs defined per-dim, per data array + elif len(window_funcs) != len(data): + raise IndexError + else: + win_funcs = window_funcs - # get max bound for all data arrays for all slider dimensions and ensure compatibility across slider dims - self._dims_max_bounds: dict[str, int] = {k: 0 for k in self.slider_dims} - for i, _dim in enumerate(list(self._dims_max_bounds.keys())): - for array, partition in zip(self.data, self.n_scrollable_dims): - if partition <= i: - continue + # verify window sizes + if window_sizes is None: + win_sizes = [window_sizes] * len(data) + + elif isinstance(window_sizes, int): + win_sizes = [window_sizes] * len(data) + + elif all([isinstance(size, int) or size is None for size in window_sizes]): + # window sizes defined per-dim across all data arrays + win_sizes = [window_sizes] * len(data) + + elif len(window_sizes) != len(data): + # window sizes defined per-dim, per data array + raise IndexError + else: + win_sizes = window_sizes + + # verify window orders + if window_order is None: + win_order = [None] * len(data) + + elif all([isinstance(o, int) for o in order]): + # window order defined per-dim across all data arrays + win_order = [window_order] * len(data) + + elif len(window_order) != len(data): + raise IndexError + + else: + win_order = window_order + + # verify spatial_func + if spatial_func is None: + spatial_func = [None] * len(data) + + elif callable(spatial_func): + # same spatial_func for all data arrays + spatial_func = [spatial_func] * len(data) + + elif len(spatial_func) != len(data): + raise IndexError + + else: + spatial_func = spatial_func + + # verify number of display dims + if isinstance(n_display_dims, (int, np.integer)): + n_display_dims = [n_display_dims] * len(data) + + elif isinstance(n_display_dims, (tuple, list)): + if not all([isinstance(n, (int, np.integer)) for n in n_display_dims]): + raise TypeError + + if len(n_display_dims) != len(data): + raise IndexError + else: + raise TypeError + + n_display_dims = tuple(n_display_dims) + + if sliders_dim_order not in ("right",): + raise ValueError( + f"Only 'right' slider dims order is currently supported, you passed: {sliders_dim_order}" + ) + self._sliders_dim_order = sliders_dim_order + + self._slider_dim_names = None + self.slider_dim_names = slider_dim_names + + self._histogram_widget = histogram_widget + + # make NDImageArrays + self._image_processors: list[NDImageProcessor] = list() + for i in range(len(data)): + cls = processors[i] + image_processor = cls( + data=data[i], + rgb=rgb[i], + n_display_dims=n_display_dims[i], + window_funcs=win_funcs[i], + window_sizes=win_sizes[i], + window_order=win_order[i], + spatial_func=spatial_func[i], + compute_histogram=self._histogram_widget, + ) + + self._image_processors.append(image_processor) + + self._data = ImageWidgetProperty(self, "data") + self._rgb = ImageWidgetProperty(self, "rgb") + self._n_display_dims = ImageWidgetProperty(self, "n_display_dims") + self._window_funcs = ImageWidgetProperty(self, "window_funcs") + self._window_sizes = ImageWidgetProperty(self, "window_sizes") + self._window_order = ImageWidgetProperty(self, "window_order") + self._spatial_func = ImageWidgetProperty(self, "spatial_func") + + if len(set(n_display_dims)) > 1: + # assume user wants one controller for 2D images and another for 3D image volumes + n_subplots = np.prod(figure_shape) + controller_ids = [0] * n_subplots + controller_types = ["panzoom"] * n_subplots + + for i in range(len(data)): + if n_display_dims[i] == 2: + controller_ids[i] = 1 else: - if 0 < self._dims_max_bounds[_dim] != array.shape[i]: - raise ValueError(f"Two arrays differ along dimension {_dim}") - else: - self._dims_max_bounds[_dim] = max( - self._dims_max_bounds[_dim], array.shape[i] - ) + controller_ids[i] = 2 + controller_types[i] = "orbit" - figure_kwargs_default = {"controller_ids": "sync", "names": names} + # needs to be a list of list + controller_ids = [controller_ids] + + else: + controller_ids = "sync" + controller_types = None + + figure_kwargs_default = { + "controller_ids": controller_ids, + "controller_types": controller_types, + "names": names, + } # update the default kwargs with any user-specified kwargs # user specified kwargs will overwrite the defaults @@ -514,27 +318,48 @@ def __init__( figure_kwargs_default["shape"] = figure_shape if graphic_kwargs is None: - graphic_kwargs = dict() + graphic_kwargs = [dict()] * len(data) + + elif isinstance(graphic_kwargs, dict): + graphic_kwargs = [graphic_kwargs] * len(data) + + elif len(graphic_kwargs) != len(data): + raise IndexError + + if cmap is None: + cmap = [None] * len(data) - graphic_kwargs.update({"cmap": cmap}) + elif isinstance(cmap, str): + cmap = [cmap] * len(data) - vmin_specified, vmax_specified = None, None - if "vmin" in graphic_kwargs.keys(): - vmin_specified = graphic_kwargs.pop("vmin") - if "vmax" in graphic_kwargs.keys(): - vmax_specified = graphic_kwargs.pop("vmax") + elif not all([isinstance(c, str) for c in cmap]): + raise TypeError(f"`cmap` must be a or a list/tuple of ") self._figure: Figure = Figure(**figure_kwargs_default) - self._histogram_widget = histogram_widget - for data_ix, (d, subplot) in enumerate(zip(self.data, self.figure)): + self._indices = Indices(list(0 for i in range(self.n_sliders)), self) + + for i, subplot in zip(range(len(self._image_processors)), self.figure): + image_data = self._get_image( + self._image_processors[i], tuple(self._indices) + ) + + if image_data is None: + # this subplot/data array is blank, skip + continue - frame = self._process_indices(d, slice_indices=self._current_index) - frame = self._process_frame_apply(frame, data_ix) + # next 20 lines are just vmin, vmax parsing + vmin_specified, vmax_specified = None, None + if "vmin" in graphic_kwargs[i].keys(): + vmin_specified = graphic_kwargs[i].pop("vmin") + if "vmax" in graphic_kwargs[i].keys(): + vmax_specified = graphic_kwargs[i].pop("vmax") if (vmin_specified is None) or (vmax_specified is None): # if either vmin or vmax are not specified, calculate an estimate by subsampling - vmin_estimate, vmax_estimate = quick_min_max(d) + vmin_estimate, vmax_estimate = quick_min_max( + self._image_processors[i].data + ) # decide vmin, vmax passed to ImageGraphic constructor based on whether it's user specified or now if vmin_specified is None: @@ -552,272 +377,550 @@ def __init__( # both vmin and vmax are specified vmin, vmax = vmin_specified, vmax_specified - ig = ImageGraphic( - frame, - name="image_widget_managed", - vmin=vmin, - vmax=vmax, - **graphic_kwargs, - ) - subplot.add_graphic(ig) + graphic_kwargs[i]["cmap"] = cmap[i] + + if self._image_processors[i].n_display_dims == 2: + # create an Image + graphic = ImageGraphic( + data=image_data, + name="image_widget_managed", + vmin=vmin, + vmax=vmax, + **graphic_kwargs[i], + ) + elif self._image_processors[i].n_display_dims == 3: + # create an ImageVolume + graphic = ImageVolumeGraphic( + data=image_data, + name="image_widget_managed", + vmin=vmin, + vmax=vmax, + **graphic_kwargs[i], + ) + subplot.camera.fov = 50 - if self._histogram_widget: - hlut = HistogramLUTTool(data=d, images=ig, name="histogram_lut") + subplot.add_graphic(graphic) - subplot.docks["right"].add_graphic(hlut) - subplot.docks["right"].size = 80 - subplot.docks["right"].auto_scale(maintain_aspect=False) - subplot.docks["right"].controller.enabled = False - - # hard code the expected height so that the first render looks right in tests, docs etc. - if len(self.slider_dims) == 0: - ui_size = 57 - if len(self.slider_dims) == 1: - ui_size = 106 - elif len(self.slider_dims) == 2: - ui_size = 155 - - self._image_widget_sliders = ImageWidgetSliders( + self._reset_histogram(subplot, self._image_processors[i]) + + self._sliders_ui = ImageWidgetSliders( figure=self.figure, - size=ui_size, + size=57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders), location="bottom", title="ImageWidget Controls", image_widget=self, ) - self.figure.add_gui(self._image_widget_sliders) + self.figure.add_gui(self._sliders_ui) - self._current_index_changed_handlers = set() + self._indices_changed_handlers = set() self._reentrant_block = False - self._initialized = True + @property + def data(self) -> ImageWidgetProperty[ArrayProtocol | None]: + """get or set the nd-image data arrays""" + return self._data + + @data.setter + def data(self, new_data: Sequence[ArrayProtocol | None]): + if isinstance(new_data, ArrayProtocol) or new_data is None: + new_data = [new_data] * len(self._image_processors) + + if len(new_data) != len(self._image_processors): + raise IndexError + + # if the data array hasn't been changed + # graphics will not be reset for this data index + skip_indices = list() + + for i, (new_data, image_processor) in enumerate( + zip(new_data, self._image_processors) + ): + if new_data is image_processor.data: + skip_indices.append(i) + continue + + image_processor.data = new_data + + self._reset(skip_indices) + + @property + def rgb(self) -> ImageWidgetProperty[bool]: + """get or set the rgb toggle for each data array""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: Sequence[bool]): + if isinstance(rgb, bool): + rgb = [rgb] * len(self._image_processors) + + if len(rgb) != len(self._image_processors): + raise IndexError + + # if the rgb option hasn't been changed + # graphics will not be reset for this data index + skip_indices = list() + + for i, (new, image_processor) in enumerate(zip(rgb, self._image_processors)): + if image_processor.rgb == new: + skip_indices.append(i) + continue + + image_processor.rgb = new + + self._reset(skip_indices) @property - def frame_apply(self) -> dict | None: - return self._frame_apply + def n_display_dims(self) -> ImageWidgetProperty[Literal[2, 3]]: + """Get or set the number of display dimensions for each data array, 2 is a 2D image, 3 is a 3D volume image""" + return self._n_display_dims + + @n_display_dims.setter + def n_display_dims(self, new_ndd: Sequence[Literal[2, 3]] | Literal[2, 3]): + if isinstance(new_ndd, (int, np.integer)): + if new_ndd == 2 or new_ndd == 3: + new_ndd = [new_ndd] * len(self._image_processors) + else: + raise ValueError + + if len(new_ndd) != len(self._image_processors): + raise IndexError + + if not all([(n == 2) or (n == 3) for n in new_ndd]): + raise ValueError - @frame_apply.setter - def frame_apply(self, frame_apply: dict[int, callable]): - if frame_apply is None: - frame_apply = dict() + # if the n_display_dims hasn't been changed for this data array + # graphics will not be reset for this data array index + skip_indices = list() + + # first update image arrays + for i, (image_processor, new) in enumerate( + zip(self._image_processors, new_ndd) + ): + if new > image_processor.max_n_display_dims: + raise IndexError( + f"number of display dims exceeds maximum number of possible " + f"display dimensions: {image_processor.max_n_display_dims}, for array at index: " + f"{i} with shape: {image_processor.shape}, and rgb set to: {image_processor.rgb}" + ) + + if image_processor.n_display_dims == new: + skip_indices.append(i) + else: + image_processor.n_display_dims = new - self._frame_apply = frame_apply - # force update image graphic - self.current_index = self.current_index + self._reset(skip_indices) @property - def window_funcs(self) -> dict[str, _WindowFunctions]: + def window_funcs(self) -> ImageWidgetProperty[tuple[WindowFuncCallable | None] | None]: + """get or set the window functions""" + return self._window_funcs + + @window_funcs.setter + def window_funcs(self, new_funcs: Sequence[WindowFuncCallable | None] | None): + if callable(new_funcs) or new_funcs is None: + new_funcs = [new_funcs] * len(self._image_processors) + + if len(new_funcs) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_funcs", new_funcs) + + @property + def window_sizes(self) -> ImageWidgetProperty[tuple[int | None, ...] | None]: + """get or set the window sizes""" + return self._window_sizes + + @window_sizes.setter + def window_sizes( + self, new_sizes: Sequence[tuple[int | None, ...] | int | None] | int | None + ): + if isinstance(new_sizes, int) or new_sizes is None: + # same window for all data arrays + new_sizes = [new_sizes] * len(self._image_processors) + + if len(new_sizes) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_sizes", new_sizes) + + @property + def window_order(self) -> ImageWidgetProperty[tuple[int, ...] | None]: + """get or set order in which window functions are applied over dimensions""" + return self._window_order + + @window_order.setter + def window_order(self, new_order: Sequence[tuple[int, ...]]): + if new_order is None: + new_order = [new_order] * len(self._image_processors) + + if all([isinstance(order, (int, np.integer))] for order in new_order): + # same order specified across all data arrays + new_order = [new_order] * len(self._image_processors) + + if len(new_order) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_order", new_order) + + @property + def spatial_func(self) -> ImageWidgetProperty[Callable | None]: + """Get or set a spatial_func that operates on the spatial dimensions of the 2D or 3D image""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, funcs: Callable | Sequence[Callable] | None): + if callable(funcs) or funcs is None: + funcs = [funcs] * len(self._image_processors) + + if len(funcs) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("spatial_func", funcs) + + def _set_image_processor_funcs(self, attr, new_values): + """sets window_funcs, window_sizes, window_order, or spatial_func and updates displayed data and histograms""" + for new, image_processor, subplot in zip( + new_values, self._image_processors, self.figure + ): + if getattr(image_processor, attr) == new: + continue + + setattr(image_processor, attr, new) + + # window functions and spatial functions will only change the histogram + # they do not change the collections of dimensions, so we don't need to call _reset_dimensions + # they also do not change the image graphic, so we do not need to call _reset_image_graphics + self._reset_histogram(subplot, image_processor) + + # update the displayed image data in the graphics + self.indices = self.indices + + @property + def indices(self) -> ImageWidgetProperty[int]: """ - Get or set the window functions + Get or set the current indices. Returns ------- - Dict[str, _WindowFunctions] + indices: ImageWidgetProperty[int] + integer index for each slider dimension """ - return self._window_funcs + return self._indices - @window_funcs.setter - def window_funcs(self, callable_dict: dict[str, int]): - if callable_dict is None: - self._window_funcs = None - # force frame to update - self.current_index = self.current_index + @indices.setter + def indices(self, new_indices: Sequence[int]): + if self._reentrant_block: return - elif isinstance(callable_dict, dict): - if not set(callable_dict.keys()).issubset(ALLOWED_WINDOW_DIMS): - raise ValueError( - f"The only allowed keys to window funcs are {list(ALLOWED_WINDOW_DIMS)} " - f"Your window func passed in these keys: {list(callable_dict.keys())}" + try: + self._reentrant_block = True # block re-execution until new_indices has *fully* completed execution + + if len(new_indices) != self.n_sliders: + raise IndexError( + f"len(new_indices) != ImageWidget.n_sliders, {len(new_indices)} != {self.n_sliders}. " + f"The length of the new_indices must be the same as the number of sliders" ) - if not all( - [ - isinstance(_callable_dict, tuple) - for _callable_dict in callable_dict.values() - ] - ): - raise TypeError( - "dict argument to `window_funcs` must be in the form of: " - "`{dimension: (func, window_size)}`. " - "See the docstring." + + if any([i < 0 for i in new_indices]): + raise IndexError( + f"only positive index values are supported, you have passed: {new_indices}" ) - for v in callable_dict.values(): - if not callable(v[0]): - raise TypeError( - "dict argument to `window_funcs` must be in the form of: " - "`{dimension: (func, window_size)}`. " - "See the docstring." - ) - if not isinstance(v[1], int): - raise TypeError( - f"dict argument to `window_funcs` must be in the form of: " - "`{dimension: (func, window_size)}`. " - f"where window_size is integer. you passed in {v[1]} for window_size" - ) - - if not isinstance(self._window_funcs, dict): - self._window_funcs = dict() - - for k in list(callable_dict.keys()): - self._window_funcs[k] = _WindowFunctions(self, *callable_dict[k]) - else: + for image_processor, graphic in zip(self._image_processors, self.graphics): + new_data = self._get_image(image_processor, indices=new_indices) + if new_data is None: + continue + + graphic.data = new_data + + self._indices._fpl_set(new_indices) + + # call any event handlers + for handler in self._indices_changed_handlers: + handler(tuple(self.indices)) + + except Exception as exc: + # raise original exception + raise exc # indices setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._reentrant_block = False + + @property + def histogram_widget(self) -> bool: + """show or hide the histograms""" + return self._histogram_widget + + @histogram_widget.setter + def histogram_widget(self, show_histogram: bool): + if not isinstance(show_histogram, bool): raise TypeError( - f"`window_funcs` must be either Nonetype or dict." - f"You have passed a {type(callable_dict)}. See the docstring." + f"`histogram_widget` can be set with a bool, you have passed: {show_histogram}" ) - # force frame to update - self.current_index = self.current_index + for subplot, image_processor in zip(self.figure, self._image_processors): + image_processor.compute_histogram = show_histogram + self._reset_histogram(subplot, image_processor) - def _process_indices( - self, array: np.ndarray, slice_indices: dict[str, int] - ) -> np.ndarray: - """ - Get the 2D array from the given slice indices. If not returning a 2D slice (such as due to window_funcs) - then `frame_apply` must take this output and return a 2D array + @property + def n_sliders(self) -> int: + """number of sliders""" + return max([a.n_slider_dims for a in self._image_processors]) - Parameters - ---------- - array: np.ndarray - array-like to get a 2D slice from + @property + def bounds(self) -> tuple[int, ...]: + """The max bound across all dimensions across all data arrays""" + # initialize with 0 + bounds = [0] * self.n_sliders + + # TODO: implement left -> right slider dims ordering, right now it's only right -> left + # in reverse because dims go left <- right + for i, dim in enumerate(range(-1, -self.n_sliders - 1, -1)): + # across each dim + for array in self._image_processors: + if i > array.n_slider_dims - 1: + continue + # across each data array + # dims go left <- right + bounds[dim] = max(array.slider_dims_shape[dim], bounds[dim]) - slice_indices: Dict[str, int] - dict in form of {dimension_index: current_index} - For example if an array has shape [1000, 30, 512, 512] corresponding to [t, z, x, y]: - To get the 100th timepoint and 3rd z-plane pass: - {"t": 100, "z": 3} + return bounds - Returns - ------- - np.ndarray - array-like, 2D slice + @property + def slider_dim_names(self) -> tuple[str, ...]: + return self._slider_dim_names - """ + @slider_dim_names.setter + def slider_dim_names(self, names: Sequence[str]): + if names is None: + self._slider_dim_names = None + return - data_ix = None - for i in range(len(self.data)): - if self.data[i] is array: - data_ix = i - break + if not all([isinstance(n, str) for n in names]): + raise TypeError(f"`slider_dim_names` must be set with a list/tuple of , you passed: {names}") - numerical_dims = list() + if len(set(names)) != len(names): + raise ValueError( + f"`slider_dim_names` must be unique, you passed: {names}" + ) - # Totally number of dimensions for this specific array - curr_ndim = self.data[data_ix].ndim + self._slider_dim_names = tuple(names) + + def _get_image( + self, image_processor: NDImageProcessor, indices: Sequence[int] + ) -> ArrayProtocol: + """Get a processed 2d or 3d image from the NDImage at the given indices""" + n = image_processor.n_slider_dims + + if self._sliders_dim_order == "right": + return image_processor.get(indices[-n:]) + + elif self._sliders_dim_order == "left": + # TODO: left -> right is not fully implemented yet in ImageWidget + return image_processor.get(indices[:n]) + + def _reset_dimensions(self): + """reset the dimensions w.r.t. current collection of NDImageProcessors""" + # TODO: implement left -> right slider dims ordering, right now it's only right -> left + # add or remove dims from indices + # trim any excess dimensions + while len(self._indices) > self.n_sliders: + # remove outer most dims first + self._indices.pop_dim() + self._sliders_ui.pop_dim() + + # add any new dimensions that aren't present + while len(self.indices) < self.n_sliders: + # insert right -> left + self._indices.push_dim() + self._sliders_ui.push_dim() + + self._sliders_ui.size = 57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders) + + def _reset_image_graphics(self, subplot, image_processor): + """delete and create a new image graphic if necessary""" + new_image = self._get_image(image_processor, indices=tuple(self.indices)) + if new_image is None: + if "image_widget_managed" in subplot: + # delete graphic from this subplot if present + subplot.delete_graphic(subplot["image_widget_managed"]) + # skip this subplot + return - # Initialize slices for each dimension of array - indexer = [slice(None)] * curr_ndim + # check if a graphic exists + if "image_widget_managed" in subplot: + # create a new graphic only if the Texture buffer shape doesn't match + if subplot["image_widget_managed"].data.value.shape == new_image.shape: + return + + # keep cmap + cmap = subplot["image_widget_managed"].cmap + if cmap is None: + # ex: going from rgb -> grayscale + cmap = "plasma" + # delete graphic since it will be replaced + subplot.delete_graphic(subplot["image_widget_managed"]) + else: + # default cmap + cmap = "plasma" - # Maps from n_scrollable_dims to one of "", "t", "tz", etc. - curr_scrollable_format = SCROLLABLE_DIMS_ORDER[self.n_scrollable_dims[data_ix]] - for dim in list(slice_indices.keys()): - if dim not in curr_scrollable_format: - continue - # get axes order for that specific array - numerical_dim = curr_scrollable_format.index(dim) + if image_processor.n_display_dims == 2: + g = subplot.add_image( + data=new_image, cmap=cmap, name="image_widget_managed" + ) - indices_dim = slice_indices[dim] + # set camera orthogonal to the xy plane, flip y axis + subplot.camera.set_state( + { + "position": [0, 0, -1], + "rotation": [0, 0, 0, 1], + "scale": [1, -1, 1], + "reference_up": [0, 1, 0], + "fov": 0, + "depth_range": None, + } + ) - # takes care of index selection (window slicing) for this specific axis - indices_dim = self._get_window_indices(data_ix, numerical_dim, indices_dim) + subplot.controller = "panzoom" + subplot.axes.intersection = None + subplot.auto_scale() - # set the indices for this dimension - indexer[numerical_dim] = indices_dim + elif image_processor.n_display_dims == 3: + g = subplot.add_image_volume( + data=new_image, cmap=cmap, name="image_widget_managed" + ) + subplot.camera.fov = 50 + subplot.controller = "orbit" - numerical_dims.append(numerical_dim) + # make sure all 3D dimension camera scales are positive + # MIP rendering doesn't work with negative camera scales + for dim in ["x", "y", "z"]: + if getattr(subplot.camera.local, f"scale_{dim}") < 0: + setattr(subplot.camera.local, f"scale_{dim}", 1) - # apply indexing to the array - # use window function is given for this dimension - if self.window_funcs is not None: - a = array - for i, dim in enumerate(sorted(numerical_dims)): - dim_str = curr_scrollable_format[dim] - dim = dim - i # since we loose a dimension every iteration - _indexer = [slice(None)] * (curr_ndim - i) - _indexer[dim] = indexer[dim + i] + subplot.auto_scale() - # if the indexer is an int, this dim has no window func - if isinstance(_indexer[dim], int): - a = a[tuple(_indexer)] - else: - # if the indices are from `self._get_window_indices` - func = self.window_funcs[dim_str].func - window = a[tuple(_indexer)] - a = func(window, axis=dim) - return a - else: - return array[tuple(indexer)] + def _reset_histogram(self, subplot, image_processor): + """reset the histogram""" + if not self._histogram_widget: + subplot.docks["right"].size = 0 + return + + if image_processor.histogram is None: + # no histogram available for this processor + # either there is no data array in this subplot, + # or a histogram routine does not exist for this processor + subplot.docks["right"].size = 0 + return - def _get_window_indices(self, data_ix, dim, indices_dim): - if self.window_funcs is None: - return indices_dim + if "image_widget_managed" not in subplot: + # no image in this subplot + subplot.docks["right"].size = 0 + return + + image = subplot["image_widget_managed"] + + if "histogram_lut" in subplot.docks["right"]: + hlut: HistogramLUTTool = subplot.docks["right"]["histogram_lut"] + hlut.histogram = image_processor.histogram + hlut.images = image + if subplot.docks["right"].size < 1: + subplot.docks["right"].size = 80 else: - ix = indices_dim + # need to make one + hlut = HistogramLUTTool( + histogram=image_processor.histogram, + images=image, + name="histogram_lut", + ) - dim_str = SCROLLABLE_DIMS_ORDER[self.n_scrollable_dims[data_ix]][dim] + subplot.docks["right"].add_graphic(hlut) + subplot.docks["right"].size = 80 - # if no window stuff specified for this dim - if dim_str not in self.window_funcs.keys(): - return indices_dim + self.reset_vmin_vmax() - # if window stuff is set to None for this dim - # example: {"t": None} - if self.window_funcs[dim_str] is None: - return indices_dim + def _reset(self, skip_data_indices: tuple[int, ...] = None): + if skip_data_indices is None: + skip_data_indices = tuple() - window_size = self.window_funcs[dim_str].window_size + # reset the slider indices according to the new collection of dimensions + self._reset_dimensions() + # update graphics where display dims have changed accordings to indices + for i, (subplot, image_processor) in enumerate( + zip(self.figure, self._image_processors) + ): + if i in skip_data_indices: + continue - if (window_size == 0) or (window_size is None): - return indices_dim + self._reset_image_graphics(subplot, image_processor) + self._reset_histogram(subplot, image_processor) - half_window = int((window_size - 1) / 2) # half-window size - # get the max bound for that dimension - max_bound = self._dims_max_bounds[dim_str] - indices_dim = range( - max(0, ix - half_window), min(max_bound, ix + half_window) - ) - return indices_dim + # force an update + self.indices = self.indices - def _process_frame_apply(self, array, data_ix) -> np.ndarray: - if callable(self._frame_apply): - return self._frame_apply(array) + @property + def figure(self) -> Figure: + """ + ``Figure`` used by `ImageWidget`. + """ + return self._figure - if data_ix not in self._frame_apply.keys(): - return array + @property + def graphics(self) -> list[ImageGraphic]: + """List of ``ImageWidget`` managed graphics.""" + iw_managed = list() + for subplot in self.figure: + if "image_widget_managed" in subplot: + iw_managed.append(subplot["image_widget_managed"]) + else: + iw_managed.append(None) + return tuple(iw_managed) - elif self._frame_apply[data_ix] is not None: - return self._frame_apply[data_ix](array) + @property + def cmap(self) -> tuple[str | None, ...]: + """get the cmaps, or set the cmap across all images""" + return tuple(g.cmap for g in self.graphics) - return array + @cmap.setter + def cmap(self, name: str): + for g in self.graphics: + if g is None: + # no data at this index + continue - def add_event_handler(self, handler: callable, event: str = "current_index"): + if g.cmap is None: + # if rgb + continue + + g.cmap = name + + def add_event_handler(self, handler: callable, event: str = "indices"): """ Register an event handler. - Currently the only event that ImageWidget supports is "current_index". This event is - emitted whenever the index of the ImageWidget changes. + Currently the only event that ImageWidget supports is "indices". This event is + emitted whenever the indices of the ImageWidget changes. Parameters ---------- handler: callable - callback function, must take a dict as the only argument. This dict will be the `current_index` + callback function, must take a tuple of int as the only argument. This tuple will be the `indices` - event: str, "current_index" - the only supported event is "current_index" + event: str, "indices" + the only supported event is "indices" Example ------- .. code-block:: py - def my_handler(index): - print(index) - # example prints: {"t": 100} if data has only time dimension - # "z" index will be another key if present in the data, ex: {"t": 100, "z": 5} + def my_handler(indices): + print(indices) + # example prints: (100, 15) if the data has 2 slider dimensions with sliders at positions 100, 15 # create an image widget iw = ImageWidget(...) @@ -826,30 +929,36 @@ def my_handler(index): iw.add_event_handler(my_handler) """ - if event != "current_index": - raise ValueError( - "`current_index` is the only event supported by `ImageWidget`" - ) + if event != "indices": + raise ValueError("`indices` is the only event supported by `ImageWidget`") - self._current_index_changed_handlers.add(handler) + self._indices_changed_handlers.add(handler) def remove_event_handler(self, handler: callable): """Remove a registered event handler""" - self._current_index_changed_handlers.remove(handler) + self._indices_changed_handlers.remove(handler) def clear_event_handlers(self): """Clear all registered event handlers""" - self._current_index_changed_handlers.clear() + self._indices_changed_handlers.clear() def reset_vmin_vmax(self): """ Reset the vmin and vmax w.r.t. the full data """ - for data, subplot in zip(self.data, self.figure): + for image_processor, subplot in zip(self._image_processors, self.figure): if "histogram_lut" not in subplot.docks["right"]: continue + + if image_processor.histogram is None: + continue + hlut = subplot.docks["right"]["histogram_lut"] - hlut.set_data(data, reset_vmin_vmax=True) + hlut.histogram = image_processor.histogram + + edges = image_processor.histogram[1] + + hlut.vmin, hlut.vmax = edges[0], edges[-1] def reset_vmin_vmax_frame(self): """ @@ -857,130 +966,21 @@ def reset_vmin_vmax_frame(self): ImageGraphic instead of the data in the full data array. For example, if a post-processing function is used, the range of values in the ImageGraphic can be very different from the range of values in the full data array. - - TODO: We could think of applying the frame_apply funcs to a subsample of the entire array to get a better estimate of vmin vmax? """ - for subplot in self.figure: + for subplot, image_processor in zip(self.figure, self._image_processors): if "histogram_lut" not in subplot.docks["right"]: continue - hlut = subplot.docks["right"]["histogram_lut"] - # set the data using the current image graphic data - hlut.set_data(subplot["image_widget_managed"].data.value) - - def set_data( - self, - new_data: np.ndarray | list[np.ndarray], - reset_vmin_vmax: bool = True, - reset_indices: bool = True, - ): - """ - Change data of widget. Note: sliders max currently update only for ``txy`` and ``tzxy`` data. - - Parameters - ---------- - new_data: array-like or list of array-like - The new data to display in the widget - - reset_vmin_vmax: bool, default ``True`` - reset the vmin vmax levels based on the new data - - reset_indices: bool, default ``True`` - reset the current index for all dimensions to 0 - - """ - - if reset_indices: - for key in self.current_index: - self.current_index[key] = 0 - - # set slider max according to new data - max_lengths = dict() - for scroll_dim in self.slider_dims: - max_lengths[scroll_dim] = np.inf - - if _is_arraylike(new_data): - new_data = [new_data] - - if len(self._data) != len(new_data): - raise ValueError( - f"number of new data arrays {len(new_data)} must match" - f" current number of data arrays {len(self._data)}" - ) - # check all arrays - for i, (new_array, current_array) in enumerate(zip(new_data, self._data)): - if new_array.ndim != current_array.ndim: - raise ValueError( - f"new data ndim {new_array.ndim} at index {i} " - f"does not equal current data ndim {current_array.ndim}" - ) - - # Computes the number of scrollable dims and also validates new_array - new_scrollable_dims = self._get_n_scrollable_dims(new_array, self._rgb[i]) - - if self.n_scrollable_dims[i] != new_scrollable_dims: - raise ValueError( - f"number of dimensions of data arrays must match number of dimensions of " - f"existing data arrays" - ) - - # if checks pass, update with new data - for i, (new_array, current_array, subplot) in enumerate( - zip(new_data, self._data, self.figure) - ): - # if the new array is the same as the existing array, skip - # this allows setting just a subset of the arrays in the ImageWidget - if new_data is self._data[i]: + if image_processor.histogram is None: continue - # check last two dims (x and y) to see if data shape is changing - old_data_shape = self._data[i].shape[-self.n_img_dims[i] :] - self._data[i] = new_array - - if old_data_shape != new_array.shape[-self.n_img_dims[i] :]: - frame = self._process_indices( - new_array, slice_indices=self._current_index - ) - frame = self._process_frame_apply(frame, i) - - # make new graphic first - new_graphic = ImageGraphic(data=frame, name="image_widget_managed") - - if self._histogram_widget: - # set hlut tool to use new graphic - subplot.docks["right"]["histogram_lut"].images = new_graphic - - # delete old graphic after setting hlut tool to new graphic - # this ensures gc - subplot.delete_graphic(graphic=subplot["image_widget_managed"]) - subplot.insert_graphic(graphic=new_graphic) - - # Returns "", "t", or "tz" - curr_scrollable_format = SCROLLABLE_DIMS_ORDER[self.n_scrollable_dims[i]] - - for scroll_dim in self.slider_dims: - if scroll_dim in curr_scrollable_format: - new_length = new_array.shape[ - curr_scrollable_format.index(scroll_dim) - ] - if max_lengths[scroll_dim] == np.inf: - max_lengths[scroll_dim] = new_length - elif max_lengths[scroll_dim] != new_length: - raise ValueError( - f"New arrays have differing values along dim {scroll_dim}" - ) - - self._dims_max_bounds[scroll_dim] = max_lengths[scroll_dim] - - # set histogram widget - if self._histogram_widget: - subplot.docks["right"]["histogram_lut"].set_data( - new_array, reset_vmin_vmax=reset_vmin_vmax - ) - - # force graphics to update - self.current_index = self.current_index + hlut = subplot.docks["right"]["histogram_lut"] + # set the data using the current image graphic data + image = subplot["image_widget_managed"] + freqs, edges = np.histogram(image.data.value, bins=100) + hlut.histogram = (freqs, edges) + hlut.vmin, hlut.vmax = edges[0], edges[-1] def show(self, **kwargs): """ @@ -990,7 +990,7 @@ def show(self, **kwargs): ---------- kwargs: Any - passed to `Figure.show()` + passed to `Figure.show()`t Returns -------