# Source code for sphero_vem.segmentation.cellpose.utils

"""
Various utility functions for Cellpose segmentation
"""

from pathlib import Path
import numpy as np
from scipy.special import expit
import pandas as pd
from cellpose import metrics
import zarr
from skimage import graph
from sphero_vem.io import write_zarr
from sphero_vem.preprocessing import resample_array
from sphero_vem.utils import dirname_from_spacing
from sphero_vem.utils.accelerator import (
    xp,
    ndi,
    ArrayLike,
    gpu_dispatch,
    to_host,
)


@gpu_dispatch(return_to_host=True)
def _upsample_seeds(
    labels_lr: ArrayLike, erosion_iterations: int, zoom_factors: ArrayLike
) -> np.ndarray:
    """Build high-resolution seeds from a low-resolution label array.

    The foreground (``labels_lr > 0``) is shrunk by repeated binary erosion,
    the labels are masked with the shrunken foreground, and the result is
    zoomed with order-0 (nearest-neighbour) interpolation so that label IDs
    are never blended.

    Parameters
    ----------
    labels_lr : ArrayLike
        Low-resolution integer label array (0 = background).
    erosion_iterations : int
        How many single-step binary erosions to apply to the foreground
        before zooming.
    zoom_factors : ArrayLike
        Zoom factor for each axis.

    Returns
    -------
    np.ndarray
        High-resolution seed array.
    """
    core = labels_lr > 0
    for _step in range(erosion_iterations):
        core = ndi.binary_erosion(core)
    # Multiplying by the boolean core zeroes every label voxel that was eroded.
    return ndi.zoom(labels_lr * core, zoom_factors, order=0)


@gpu_dispatch(return_to_host=True)
def _calc_foreground(
    cellprob_hr: ArrayLike,
    cellprob_threshold: float,
) -> np.ndarray:
    """Binarize a high-resolution cellprob map into a foreground mask.

    Parameters
    ----------
    cellprob_hr : ArrayLike
        High-resolution cellprob logit array.
    cellprob_threshold : float
        Voxels whose logit is strictly greater than this value are foreground.

    Returns
    -------
    numpy.ndarray
        Boolean foreground mask with the same shape as *cellprob_hr*.
    """
    return cellprob_hr > cellprob_threshold


@gpu_dispatch(return_to_host=True)
def region_fill(
    seeds: ArrayLike, foreground: ArrayLike, max_expansion_steps: int = 10
) -> np.ndarray:
    """Expand integer seeds by iterative grey dilation, constrained to a foreground mask.

    Each iteration grows labels by one voxel using face connectivity
    (6-connectivity in 3D). Only unlabeled voxels inside the foreground are
    filled, so background regions act as hard barriers. Where several labels
    reach the same voxel in one step, the largest label ID wins because grey
    dilation takes the neighbourhood maximum.

    Parameters
    ----------
    seeds : ArrayLike
        Integer label array (0 = background). The input is not modified.
    foreground : ArrayLike
        Mask restricting the expansion. It is interpreted as boolean
        (non-zero = foreground).
    max_expansion_steps : int, optional
        Maximum dilation iterations. Default 10.

    Returns
    -------
    np.ndarray
        Expanded label array, zero outside foreground.
    """
    labels = xp.copy(seeds)
    # Coerce to bool: with an integer mask the `&` below would yield an integer
    # array (turning the assignment into fancy indexing) and the final
    # multiplication would rescale label IDs.
    mask = foreground.astype(bool, copy=False)
    structure = ndi.generate_binary_structure(rank=labels.ndim, connectivity=1)
    # Reused across iterations to avoid reallocating a full-size array each step.
    dilation_buffer = xp.empty_like(labels)

    for _ in range(max_expansion_steps):
        ndi.grey_dilation(labels, footprint=structure, output=dilation_buffer)
        should_fill = (labels == 0) & mask & (dilation_buffer > 0)
        if not xp.any(should_fill):
            # Nothing left to grow into: converged early.
            break
        labels[should_fill] = dilation_buffer[should_fill]

    # Seeds may start outside the foreground; drop them.
    labels *= mask
    return labels
def _upsample_masks_region_fill(
    labels_lr: np.ndarray,
    cellprob_hr: np.ndarray,
    erosion_iterations: int = 2,
    cellprob_threshold: float = 0.0,
    max_expansion_steps: int = 50,
) -> np.ndarray:
    """Upsample low-resolution labels by seeded region fill on the cellprob map.

    Low-resolution labels are eroded and zoomed to the shape of
    *cellprob_hr* to obtain seeds, which are then grown inside the region
    where the cellprob logits exceed *cellprob_threshold*.

    Parameters
    ----------
    labels_lr : np.ndarray
        Low-resolution integer label volume.
    cellprob_hr : np.ndarray
        High-resolution cellprob logit volume.
    erosion_iterations : int, optional
        Erosion steps applied before zooming the seeds. Default 2.
    cellprob_threshold : float, optional
        Logit threshold defining the foreground mask. Default 0.0.
    max_expansion_steps : int, optional
        Maximum number of dilation iterations. Default 50.

    Returns
    -------
    np.ndarray
        Upsampled label volume with the dtype of *labels_lr*.
    """
    hr_shape = np.array(cellprob_hr.shape)
    lr_shape = np.array(labels_lr.shape)
    hr_seeds = _upsample_seeds(labels_lr, erosion_iterations, hr_shape / lr_shape)
    inside = cellprob_hr > cellprob_threshold
    filled = region_fill(hr_seeds, inside, max_expansion_steps)
    return filled.astype(labels_lr.dtype)
def upsample_masks(
    root_path: Path,
    seg_target: str,
    target_spacing: tuple[int, float],
    src_spacing: tuple[int, int, int] = (100, 100, 100),
    erosion_iterations: int = 2,
    cellprob_threshold: float = 0.0,
    store_chunks: tuple[int] | None = None,
    label_root: str | None = None,
    n_workers: int = 4,
) -> None:
    """Upsample low-resolution Cellpose labels to a higher-resolution target spacing.

    Reads low-resolution masks and the corresponding high-resolution cellprob
    array from a Zarr archive, runs region-fill upsampling, and writes the
    result back to the archive under the target spacing path. If the
    high-resolution cellprob array does not exist yet, it is first created by
    resampling the source-spacing cellprob array.

    Parameters
    ----------
    root_path : Path
        Path to the root Zarr archive.
    seg_target : str
        Name of the segmentation target (e.g. ``"cells"`` or ``"nuclei"``).
    target_spacing : tuple[int, float]
        Target voxel spacing (Z, Y, X) in nanometers.
    src_spacing : tuple[int, int, int], optional
        Source (low-resolution) voxel spacing. Default is ``(100, 100, 100)``.
    erosion_iterations : int, optional
        Number of erosion steps applied to seeds before zooming. Default is 2.
    cellprob_threshold : float, optional
        Logit threshold for the foreground mask. Default is 0.0.
    store_chunks : tuple[int] | None, optional
        Chunk shape for the output Zarr array. If None, inherits source chunks.
    label_root : str | None, optional
        Optional prefix path for the label group within the archive. If None,
        labels are read from ``labels/{seg_target}/...``.
    n_workers : int, optional
        Number of workers passed to ``resample_array`` when the
        high-resolution cellprob array has to be computed. Default is 4.

    Raises
    ------
    KeyError
        If the low-resolution masks are missing from the archive, or if the
        high-resolution cellprob array is still missing after resampling.
    """
    root = zarr.open_group(root_path, mode="a")
    label_path = (
        f"{label_root}/labels/{seg_target}"
        if label_root is not None
        else f"labels/{seg_target}"
    )

    labels_lr_path = f"{label_path}/masks/{dirname_from_spacing(src_spacing)}"
    labels_lr_zarr: zarr.Array | None = root.get(labels_lr_path)
    if labels_lr_zarr is None:
        raise KeyError(
            f"Low-resolution masks not found at '{labels_lr_path}' in {root_path}"
        )

    # Try to load high resolution cellprob, and calculate it if it doesn't exist
    cellprob_hr_path = (
        f"{label_path}/flows/cellprob/{dirname_from_spacing(target_spacing)}"
    )
    cellprob_hr_zarr: zarr.Array | None = root.get(cellprob_hr_path)
    if cellprob_hr_zarr is None:
        resample_array(
            zarr_path=root_path,
            array_path=f"{label_path}/flows/cellprob/{dirname_from_spacing(src_spacing)}",
            target_spacing=target_spacing,
            n_workers=n_workers,
        )
        cellprob_hr_zarr = root.get(cellprob_hr_path)
        if cellprob_hr_zarr is None:
            raise KeyError(
                f"High-resolution cellprob not found at '{cellprob_hr_path}' "
                f"in {root_path} after resampling"
            )

    labels_lr: np.ndarray = labels_lr_zarr[:]
    cellprob_hr: np.ndarray = cellprob_hr_zarr[:]

    labels_hr = _upsample_masks_region_fill(
        labels_lr,
        cellprob_hr,
        erosion_iterations=erosion_iterations,
        cellprob_threshold=cellprob_threshold,
    )

    # The source array may carry no processing history; start from an empty
    # list instead of failing on `None + [...]`.
    previous_steps = labels_lr_zarr.attrs.get("processing") or []
    processing = previous_steps + [
        {
            "step": "upsample masks",
            "erosion_iterations": erosion_iterations,
            "cellprob_threshold": cellprob_threshold,
        }
    ]

    write_zarr(
        root,
        labels_hr,
        f"{label_path}/masks/{dirname_from_spacing(target_spacing)}",
        src_zarr=labels_lr_zarr,
        spacing=target_spacing,
        processing=processing,
        zarr_chunks=store_chunks if store_chunks else labels_lr_zarr.chunks,
        inputs=[labels_lr_zarr.path, cellprob_hr_zarr.path],
    )
def match_predictions(ground_truth: np.ndarray, predictions: np.ndarray) -> np.ndarray:
    """Remap predicted label IDs to match ground-truth label IDs.

    Uses Cellpose IoU matching to align predicted labels to ground-truth
    labels so that paired labels share the same integer ID. Predicted labels
    without a ground-truth partner are shifted to ``predictions.max() + label``.
    Background (0) is always left untouched.

    Parameters
    ----------
    ground_truth : numpy.ndarray
        Integer label array of ground-truth segmentations.
    predictions : numpy.ndarray
        Non-negative integer label array of predicted segmentations.

    Returns
    -------
    numpy.ndarray
        Relabeled prediction array where matched labels carry the same ID as
        the corresponding ground-truth label. The dtype is that of
        *predictions*, widened only if the new IDs would not fit.
    """
    # matched[i] is the predicted label paired with ground-truth label i + 1.
    _, matched = metrics.mask_ious(ground_truth, predictions)
    matched = np.asarray(matched, dtype=np.intp).ravel()

    pred_max = int(predictions.max())
    lut_size = max(pred_max, int(matched.max(initial=0))) + 1

    # Lookup table old label -> new label. Unmatched labels default to
    # `pred_max + label`; background stays 0.
    lut = np.arange(lut_size, dtype=np.int64) + pred_max
    lut[0] = 0

    # A zero entry means "this ground-truth label has no partner"; it must be
    # skipped, otherwise the background itself would be relabeled.
    has_match = matched > 0
    lut[matched[has_match]] = np.flatnonzero(has_match) + 1

    # Keep the input dtype whenever the largest new ID fits, widen otherwise
    # instead of silently wrapping around.
    out_dtype = np.promote_types(
        predictions.dtype, np.min_scalar_type(int(lut.max()))
    )
    return lut.astype(out_dtype)[predictions]
@gpu_dispatch(return_to_host=True)
def _get_edges_and_nodes(
    labels: ArrayLike, cellprob: ArrayLike, edge_map: ArrayLike
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Find adjacencies and sample cellprob and edge values at region boundaries.

    A voxel is a boundary voxel when its label differs from the maximum label
    in its face-connected neighbourhood (6-connectivity in 3D), obtained by
    grey dilation. Because dilation takes the maximum, each boundary is
    detected only from the side of the smaller label, and background (0)
    voxels touching a label are included.

    Parameters
    ----------
    labels : ArrayLike
        Integer label array (0 = background).
    cellprob : ArrayLike
        Cellprob logit array, same shape as *labels*.
    edge_map : ArrayLike
        Normalized edge strength map in [0, 1], same shape as *labels*.

    Returns
    -------
    labels_a : numpy.ndarray
        Label of each boundary voxel itself.
    labels_b : numpy.ndarray
        Largest neighbouring label of each boundary voxel.
    probs : numpy.ndarray
        Cellprob logit values at boundary voxels.
    edges : numpy.ndarray
        Edge map values at boundary voxels.
    nodes : numpy.ndarray
        All unique label IDs present in *labels*.
    counts : numpy.ndarray
        Voxel count for each unique label in *nodes*.
    """
    # Face-connectivity structuring element matching the input dimensionality
    structure = ndi.generate_binary_structure(rank=labels.ndim, connectivity=1)

    # Find edges
    dilated_labels = ndi.grey_dilation(labels, footprint=structure)
    boundary_mask = labels != dilated_labels

    labels_a = to_host(labels[boundary_mask])
    labels_b = to_host(dilated_labels[boundary_mask])
    probs = to_host(cellprob[boundary_mask])
    edges = to_host(edge_map[boundary_mask])

    # Find nodes
    all_labels, pixel_counts = xp.unique(labels, return_counts=True)
    nodes = to_host(all_labels)
    counts = to_host(pixel_counts)

    return labels_a, labels_b, probs, edges, nodes, counts
def build_rag(
    labels: np.ndarray, cellprob: np.ndarray, edge_map: np.ndarray
) -> graph.RAG:
    """
    Builds a RAG of cellpose labels using GPU acceleration when possible.

    Parameters
    ----------
    labels : np.ndarray
        Integer array containing the predicted labels
    cellprob : np.ndarray
        Cellprob array containing the cell probability logits
    edge_map : np.ndarray
        Edge map normalized to [0, 1]

    Returns
    -------
    graph.RAG
        Region adjacency graph whose node IDs are plain Python integers.
        Nodes carry ``labels``, ``pixel_count`` and ``total_surface_area``
        (initialised to 0). Edges carry the parameters

        - 'prob_weight': mean cell probability (probability of mean logit value)
        - 'edge_weight': mean edge probability
        - 'count': number of boundary voxels (doubled, see below)
        - 'weight': 1 - prob_weight + edge_weight
    """
    # 1. Get all node and edge data
    labels_a, labels_b, probs, edges, node_ids, node_counts = _get_edges_and_nodes(
        labels, cellprob, edge_map
    )

    # 2. Format node data (``tolist`` yields plain Python ints)
    node_attr = [
        (label, {"labels": [label], "pixel_count": count, "total_surface_area": 0})
        for label, count in zip(node_ids.tolist(), node_counts.tolist())
    ]

    # 3. Format edge data. Sorting each pair makes (a, b) and (b, a) one edge.
    df = pd.DataFrame({"prob": probs, "edge": edges})
    pair = np.sort(np.stack([labels_a, labels_b], axis=1), axis=1)
    df["label_1"] = pair[:, 0]
    df["label_2"] = pair[:, 1]

    # Aggregate probability, edge, and count
    edge_stats = (
        df.groupby(["label_1", "label_2"])
        .agg(
            prob_weight=("prob", "mean"),
            edge_weight=("edge", "mean"),
            count=("prob", "size"),
        )
        .reset_index()
    )
    edge_stats["prob_weight"] = expit(edge_stats["prob_weight"])
    edge_stats["weight"] = 1 - (edge_stats["prob_weight"] - edge_stats["edge_weight"])

    # Create the edge list for the graph constructor. Columns are iterated
    # directly: row-wise iteration would upcast every value (including the
    # label IDs) to float64.
    edge_list = [
        (
            int(label_1),
            int(label_2),
            {
                "weight": float(weight),
                "prob_weight": float(prob_weight),
                "edge_weight": float(edge_weight),
                # Correct for 1-sided detection
                "count": int(count) * 2,
            },
        )
        for label_1, label_2, weight, prob_weight, edge_weight, count in zip(
            edge_stats["label_1"],
            edge_stats["label_2"],
            edge_stats["weight"],
            edge_stats["prob_weight"],
            edge_stats["edge_weight"],
            edge_stats["count"],
        )
    ]

    rag = graph.RAG(label_image=None, data=edge_list)
    rag.add_nodes_from(node_attr)

    return rag
def calc_surface_rag(rag: graph.RAG) -> dict[int, float]:
    """Approximate the surface of every label from a region adjacency graph.

    The surface of a node is the sum of the ``"count"`` attribute over all
    edges incident to it; edges without that attribute contribute 0.

    Parameters
    ----------
    rag : skimage.graph.RAG
        A region adjacency graph. The RAG should include background nodes and
        use only face connectivity (connectivity=1).

    Returns
    -------
    dict[int, float]
        A dictionary in the form {label_num: total_surface}
    """
    surface = dict.fromkeys(rag.nodes, 0)
    for node_a, node_b, attrs in rag.edges(data=True):
        boundary = int(attrs.get("count", 0))
        # Every edge contributes to both of its endpoints.
        for node in (node_a, node_b):
            surface[node] += boundary
    return surface
@gpu_dispatch(return_to_host=True)
def gaussian_edge_map(image: ArrayLike, sigma: float | int) -> ArrayLike:
    """Calculate edge map of an image using Gaussian-smoothed gradient magnitude

    The edge map is clipped to 1st and 99th percentile and normalized.
    If both percentiles coincide (e.g. for a constant image) there is nothing
    to normalize and an all-zero map is returned.
    The function automatically uses GPU acceleration when available.

    Parameters
    ----------
    image : ArrayLike
        The image to be analyzed.
    sigma : float | int
        The standard deviation of the Gaussian filter applied before gradient
        calculation.

    Returns
    -------
    ArrayLike
        The edge map of the image, normalized to [0, 1].
    """
    # The third positional argument is the output dtype of the filter.
    edge_map = ndi.gaussian_gradient_magnitude(image, sigma, np.float32)
    p1, p99 = xp.percentile(edge_map, (1, 99))
    span = p99 - p1
    if float(span) == 0.0:
        # Zero dynamic range: dividing would fill the map with NaN.
        return xp.zeros_like(edge_map)
    return xp.clip((edge_map - p1) / span, 0, 1)
def rag_to_df(rag: graph.RAG) -> pd.DataFrame:
    """Flatten the edges of a RAG into a DataFrame for inspection and debugging.

    Every edge becomes one row holding its two endpoints followed by all
    attributes stored on that edge.

    Parameters
    ----------
    rag : graph.RAG
        Region adjacency graph, as returned by `build_rag`.

    Returns
    -------
    pd.DataFrame
        One row per edge. Columns ``u`` and ``v`` hold the two adjacent label
        IDs; the remaining columns are the edge attributes. For a graph
        produced by `build_rag` these are:

        - ``weight`` : float — merge cost (1 - prob_weight + edge_weight).
        - ``prob_weight`` : float — sigmoid of the mean cellprob logit across
          boundary voxels.
        - ``edge_weight`` : float — mean edge map value across boundary voxels.
        - ``count`` : int — boundary voxel count between the two regions.

    See Also
    --------
    build_rag : Constructs the RAG from Cellpose labels, cellprob logits, and
        an edge map.
    """
    records = []
    for node_u, node_v, attrs in rag.edges(data=True):
        record = {"u": node_u, "v": node_v}
        # Edge attributes are merged last, as in a ``{**attrs}`` expansion.
        record.update(attrs)
        records.append(record)
    return pd.DataFrame(records)