# Source code for sphero_vem.preprocessing

"""
Functions for preprocessing images.
"""

import warnings
from pathlib import Path
from typing import Literal

import dask
import dask.array as da
import dask_image
import dask_image.ndfilters
import dask_image.ndinterp
import numpy as np
import torch
import torch.nn.functional as F
import zarr
from dask.diagnostics import ProgressBar
from torchvision.transforms.functional import resize
from tqdm import tqdm

from sphero_vem.utils import dirname_from_spacing, create_ome_multiscales


def create_pyramid(
    image: torch.Tensor, num_levels: int, factor: int
) -> list[torch.Tensor]:
    """Build a multi-resolution image pyramid.

    Parameters
    ----------
    image : torch.Tensor
        Input image tensor at full resolution.
    num_levels : int
        Total number of pyramid levels, including the full-resolution image.
    factor : int
        Downsampling factor between consecutive levels.

    Returns
    -------
    list[torch.Tensor]
        List of image tensors ordered from coarsest to finest resolution.
    """
    pyramid = [image]
    for _ in range(num_levels - 1):
        # torchvision's resize with a single int target scales the smaller
        # edge; this sizing from shape[-1] assumes square (H == W) images —
        # TODO confirm with callers.
        image = resize(image, image.shape[-1] // factor)
        pyramid.append(image)
    # Levels were appended finest-first; reverse so the coarsest comes first.
    return list(reversed(pyramid))
def downscale_tensor(
    image: torch.Tensor, factor: int, mode: str = "bilinear"
) -> torch.Tensor:
    """Downscale a tensor or batch of tensors by an integer factor.

    Parameters
    ----------
    image : torch.Tensor
        Input tensor of shape ``(..., H, W)``. Unsqueezed to 4-D internally
        if necessary before interpolation.
    factor : int
        Integer downsampling factor. Output spatial dimensions are
        ``H // factor`` × ``W // factor``.
    mode : str, optional
        Interpolation mode passed to ``torch.nn.functional.interpolate``.
        Default is ``"bilinear"``. Use ``"nearest"`` for label maps.

    Returns
    -------
    torch.Tensor
        Downscaled tensor with the same number of dimensions as the input.
    """
    n_dim = image.dim()
    # F.interpolate requires an (N, C, H, W) input for 2-D spatial modes.
    while image.dim() < 4:
        image = image.unsqueeze(0)
    if mode == "nearest":
        # Nearest neighbour: no anti-aliasing (preserves discrete label
        # values) and align_corners is not applicable.
        image_ds: torch.Tensor = F.interpolate(
            image,
            scale_factor=1 / factor,
            mode=mode,
        )
    else:
        image_ds = F.interpolate(
            image,
            scale_factor=1 / factor,
            mode=mode,
            align_corners=False,
            antialias=True,
        )
    # Restore the caller's original dimensionality.
    while image_ds.dim() > n_dim:
        image_ds = image_ds.squeeze(0)
    return image_ds
def resample_array(
    zarr_path: Path,
    array_path: str,
    target_spacing: tuple[int, int, int],
    order: int = 1,
    zarr_chunks: tuple[int, int, int] = (1, 1024, 1024),
    n_workers: int = 4,
) -> None:
    """Resample an array in a Zarr archive to the target voxel spacing.

    Uses a lazy Gaussian pre-blur followed by affine transform via
    dask_image, keeping memory usage bounded to chunk size throughout.
    Anti-aliasing is applied only along downsampled axes, mirroring
    skimage.transform.resize. Integer label data (integer dtype + order=0)
    is resampled without anti-aliasing. float16 arrays are promoted to
    float32 for processing and cast back on output, as scipy.ndimage does
    not support float16.

    The result is written to ``<parent-of-src>/<spacing_dir>`` inside the
    archive, with ``spacing``/``processing``/``inputs`` attributes recorded,
    and the parent group's OME multiscales metadata is regenerated.

    NOTE(review): uses ``dask_image.ndfilters.gaussian_filter`` — requires
    ``import dask_image.ndfilters`` at module level.

    Parameters
    ----------
    zarr_path : Path
        Path to the Zarr archive.
    array_path : str
        Path to the source array within the archive.
    target_spacing : tuple[int, int, int]
        Target voxel spacing (Z, Y, X) in nanometers.
    order : int
        Spline interpolation order. 0 = nearest neighbour (labels),
        1 = linear (images). Default 1.
    zarr_chunks : tuple[int, int, int]
        Chunk shape for the output Zarr array.
    n_workers : int
        Number of threads for dask's threaded scheduler. Default 4.

    Raises
    ------
    FileNotFoundError
        If *array_path* does not exist under *zarr_path*.
    """
    # Warn when the per-chunk Gaussian pre-blur kernel along Z could be
    # large enough that chunk-boundary artifacts become visible.
    Z_SCALE_WARN_THRESHOLD = 2.5
    root = zarr.open_group(zarr_path)
    src_array: zarr.Array = root.get(array_path)
    if src_array is None:
        raise FileNotFoundError(
            f"Source array {array_path} not found under {zarr_path}"
        )
    src_dtype = np.dtype(src_array.dtype)
    original_shape = np.array(src_array.shape)
    # Source voxel spacing is read from the array's own attributes.
    src_spacing = np.asarray(src_array.attrs["spacing"])
    spacing_dir = dirname_from_spacing(target_spacing)
    # If an image at the target spacing already exists, match its shape
    # exactly so resampled labels/images stay aligned voxel-for-voxel.
    ref_array: zarr.Array = root.get(f"images/{spacing_dir}")
    if ref_array is not None:
        target_shape = list(ref_array.shape)
    else:
        ratio = src_spacing / np.asarray(target_spacing)
        target_shape = (ratio * original_shape).astype(int).tolist()
    # scale_factors[i] > 1 means axis i is downsampled.
    scale_factors = original_shape / np.array(target_shape)
    z_magnitude = max(scale_factors[0], 1.0 / scale_factors[0])
    if z_magnitude > Z_SCALE_WARN_THRESHOLD:
        direction = "downsampling" if scale_factors[0] > 1 else "upsampling"
        warnings.warn(
            f"Large Z {direction} factor {z_magnitude:.2f}x "
            f"(src spacing: {src_spacing[0]}, target: {target_spacing[0]}). "
            f"Gaussian pre-blur operates per-chunk; boundary artifacts may "
            f"appear at chunk edges for large kernels.",
            UserWarning,
            stacklevel=2,
        )
    # Labels (integer dtype + nearest-neighbour order) must not be blurred.
    is_label = np.issubdtype(src_dtype, np.integer) and order == 0
    # scipy.ndimage (underlying dask_image) does not support float16.
    working_dtype = (
        np.dtype("float32") if src_dtype == np.dtype("float16") else src_dtype
    )
    # Thicker Z chunks than the output layout so the Z blur/interpolation
    # sees several slices per chunk.
    temp_chunks = (8, 1024, 1024)
    src_dask: da.Array = da.from_zarr(src_array).rechunk(temp_chunks)
    if working_dtype != src_dtype:
        src_dask = src_dask.astype(working_dtype)
    aa_sigmas = np.zeros(3)
    if not is_label:
        # Anti-aliasing sigma per axis, zero for non-downsampled axes —
        # same formula as skimage.transform.resize's default.
        aa_sigmas = np.maximum(0.0, (scale_factors - 1.0) / 2.0)
        if np.any(aa_sigmas > 0):
            src_dask = dask_image.ndfilters.gaussian_filter(
                src_dask, sigma=aa_sigmas.tolist()
            )
    # Diagonal affine matrix maps output coordinates back to input
    # coordinates (input = scale * output), i.e. a pure rescale.
    resampled_dask: da.Array = dask_image.ndinterp.affine_transform(
        src_dask,
        matrix=np.diag(scale_factors),
        output_shape=target_shape,
        output_chunks=temp_chunks,
        order=order,
        mode="nearest",
    )
    if working_dtype != src_dtype:
        resampled_dask = resampled_dask.astype(src_dtype)
    # Write next to the source array, in a directory named for the spacing.
    parent_group: zarr.Group = root.get(str(Path(src_array.path).parent))
    dst_zarr_path = f"{parent_group.path}/{spacing_dir}"
    dst_zarr = root.require_array(
        name=dst_zarr_path,
        shape=target_shape,
        chunks=zarr_chunks,
        dtype=src_dtype,
        compressors=src_array.compressors,
        overwrite=True,
    )
    with ProgressBar(), dask.config.set(num_workers=n_workers):
        resampled_dask.to_zarr(dst_zarr, overwrite=True)
    # Record provenance: spacing, processing history, and the input path.
    dst_zarr.attrs["spacing"] = target_spacing
    dst_zarr.attrs["processing"] = src_array.attrs.get("processing", []) + [
        {
            "step": "resample",
            "order": order,
            "scale_factors": scale_factors.tolist(),
            "anti_aliasing": not is_label,
            "anti_aliasing_sigma": aa_sigmas.tolist(),
        }
    ]
    dst_zarr.attrs["inputs"] = src_array.path
    # Refresh OME-NGFF multiscales metadata for the parent group.
    create_ome_multiscales(parent_group)
def rechunk_array(
    root: zarr.Group,
    src_array_path: str,
    dst_array_path: str,
    dst_chunks: tuple[int, int, int] = (1, 1024, 1024),
    copy_attributes: bool = True,
    delete_src: bool = False,
    verbose: bool = True,
) -> zarr.Array:
    """Copy a Zarr array to a new path with a different chunk layout.

    The copy proceeds one Z-slice at a time so memory stays bounded to a
    single slice regardless of array size.

    Parameters
    ----------
    root : zarr.Group
        Root Zarr group containing the source array.
    src_array_path : str
        Path to the source array within *root*.
    dst_array_path : str
        Path for the destination array within *root*. Created or
        overwritten.
    dst_chunks : tuple[int, int, int], optional
        Chunk shape for the output array. Default is ``(1, 1024, 1024)``.
    copy_attributes : bool, optional
        If True, copy all Zarr attributes from source to destination.
        Default is True.
    delete_src : bool, optional
        If True, delete the source array after copying. Default is False.
    verbose : bool, optional
        If True, show a tqdm progress bar. Default is True.

    Returns
    -------
    zarr.Array
        The newly created destination array.

    Raises
    ------
    FileNotFoundError
        If *src_array_path* does not exist within *root*.
    """
    source: zarr.Array = root.get(src_array_path)
    if source is None:
        raise FileNotFoundError(f"Temp array {src_array_path} not found")

    # Destination mirrors the source's shape, dtype, and compression; only
    # the chunk layout changes.
    destination = root.require_array(
        name=dst_array_path,
        shape=source.shape,
        chunks=dst_chunks,
        dtype=source.dtype,
        compressors=source.compressors,
        overwrite=True,
    )

    # Slice-by-slice copy along the leading (Z) axis.
    n_slices = source.shape[0]
    for idx in tqdm(range(n_slices), disable=not verbose):
        destination[idx] = source[idx]

    if copy_attributes:
        # Replace (not merge) the destination's attributes with the source's.
        destination.attrs.clear()
        destination.attrs.update(dict(source.attrs))

    if delete_src:
        del root[src_array_path]

    return destination
def crop_to_valid(
    data: np.ndarray, mode: Literal["nonzero", "notnan"] = "nonzero"
) -> np.ndarray:
    """
    Crop a 3D array to the bounding box of valid data.

    Parameters
    ----------
    data : np.ndarray
        The 3D input array.
    mode : Literal["nonzero", "notnan"], optional
        The validity criteria: "nonzero" (default) or "notnan".

    Returns
    -------
    np.ndarray
        The cropped array.

    Raises
    ------
    ValueError
        If mode is not a valid value.
    """
    # Build the validity mask according to the requested criterion.
    if mode == "nonzero":
        valid = data != 0
    elif mode == "notnan":
        valid = ~np.isnan(data)
    else:
        raise ValueError(
            f"Mode {mode} not recognized. Valid options are 'nonzero' and 'notnan'"
        )

    indices = np.argwhere(valid)
    # Nothing valid anywhere: return the input untouched.
    if indices.size == 0:
        return data

    # Per-axis inclusive start / exclusive stop of the bounding box.
    lower = indices.min(axis=0)
    upper = indices.max(axis=0) + 1
    window = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))
    return data[window]