"""
Cell and nucleus segmentation with Cellpose-SAM
"""
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from cellpose import models, dynamics
import torch
import numpy as np
import zarr
from sphero_vem.io import write_zarr
from sphero_vem.utils import vprint
from sphero_vem.utils.config import BaseConfig, ProcessingStep
from sphero_vem.segmentation.cellpose.postptocessing import (
merge_labels,
decompose_flow,
expand_labels,
)
from sphero_vem.postprocessing import median_filter, guided_filter
@dataclass
class CellposeFlowConfig(BaseConfig):
    """Configuration for Cellpose flow computation on a Zarr volume stack.

    Parameters
    ----------
    root_path : Path
        Root directory of the Zarr archive.
    model : str
        Model identifier. Use ``"cpsam"`` for the pretrained Cellpose-SAM
        model, or a path-encoded custom model name of the form
        ``"cellposeSAM-<target>-..."``.
    spacing_dir : str
        Spacing directory name used to locate the source image array
        (e.g. ``"100-100-100"``).
    out_path : Path | None, optional
        Output Zarr root path. If None, writes under *root_path*.
    verbose : bool, optional
        Enable progress messages. Default is True.
    zarr_chunks : tuple[int] | None, optional
        Chunk shape for output arrays. If None, inherits source chunks.
    batch_size : int, optional
        Inference batch size. Default is 64.
    flow3D_smooth : int, optional
        Gaussian smoothing iterations applied to 3D flows. Default is 2.
    augment : bool, optional
        Enable test-time augmentation. Default is False.
    tile_overlap : float, optional
        Fraction of overlap between inference tiles. Default is 0.3.
    median_filter_cellprob : bool, optional
        Apply a 3D median filter to the cellprob map. Default is False.
    median_filter_size : int, optional
        Kernel size for the median filter. Default is 3.
    decompose_flows : bool, optional
        Apply Helmholtz-Hodge flow decomposition. Default is False.
    decompose_flows_pad_fraction : float, optional
        Z-padding fraction for flow decomposition. Default is 0.3.
    guided_filter_cellprob : bool, optional
        Apply a guided filter to the cellprob map. Default is False.
    guided_filter_radius : int, optional
        Half-window radius for the guided filter. Default is 8.
    guided_filter_eps : float, optional
        Regularization parameter for the guided filter. Default is 1e-2.
    save_raw_flows : bool, optional
        Save unprocessed flows alongside processed ones. Default is False.

    Raises
    ------
    ValueError
        If a custom *model* name does not encode a segmentation target.
    """

    root_path: Path
    model: str
    spacing_dir: str
    out_path: Path | None = None
    verbose: bool = True
    zarr_chunks: tuple[int] | None = None
    batch_size: int = 64
    flow3D_smooth: int = 2
    augment: bool = False
    tile_overlap: float = 0.3
    median_filter_cellprob: bool = False
    median_filter_size: int = 3
    decompose_flows: bool = False
    decompose_flows_pad_fraction: float = 0.3
    guided_filter_cellprob: bool = False
    guided_filter_radius: int = 8
    guided_filter_eps: float = 1e-2
    save_raw_flows: bool = False
    # Derived in __post_init__; excluded from the init signature.
    seg_target: str = field(init=False)
    model_dir: Path = field(init=False)
    spacing: list[int | float] = field(init=False)
    src_zarr: zarr.Array = field(init=False)
    # Fields that should not be recorded as processing parameters.
    EXCLUDED_PROCESSING_FIELDS = {
        "root_path",
        "spacing_dir",
        "out_path",
        "verbose",
        "zarr_chunks",
        "model_dir",
        "spacing",
        "src_zarr",
        "save_raw_flows",
    }

    def __post_init__(self):
        """Derive ``seg_target``, ``model_dir``, ``src_zarr``, and ``spacing``."""
        # Allow loading pretrained model
        if self.model == "cpsam":
            # Set model_dir as empty path for compatibility with class init
            self.model_dir = Path("")
            self.seg_target = "cells"
        else:
            self.model_dir = Path(
                f"data/models/cellpose/{self.model}/models/{self.model}"
            )
            # Custom model names are expected to encode the segmentation
            # target as "cellposeSAM-<target>-...". Fail loudly instead of
            # raising an opaque AttributeError on a non-matching name.
            match = re.search(r"cellposeSAM-(\w+)-", self.model)
            if match is None:
                raise ValueError(
                    f"Cannot derive seg_target from model name {self.model!r}; "
                    "expected it to match 'cellposeSAM-<target>-'"
                )
            self.seg_target = match.group(1)
        self.src_zarr = zarr.open_array(
            self.root_path / f"images/{self.spacing_dir}", mode="r"
        )
        self.spacing = self.src_zarr.attrs.get("spacing")
        # If out_path is not specified, save under the Zarr path of the images
        if not self.out_path:
            self.out_path = self.root_path
        else:
            self.out_path.parent.mkdir(parents=True, exist_ok=True)
        if not self.zarr_chunks:
            self.zarr_chunks = self.src_zarr.chunks
def compute_raw_flows(config: CellposeFlowConfig) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute raw cellpose flows without postprocessing.

    This function performs model inference only, returning the raw displacement
    field and cell probability. No postprocessing, filtering, or saving is performed.

    Parameters
    ----------
    config : CellposeFlowConfig
        Configuration containing model parameters, image source, and inference settings.
        Only the following fields are used:
        - src_zarr: Source image array
        - model, model_dir: Model specification
        - batch_size, tile_overlap, flow3D_smooth, augment: Inference parameters
        - verbose: Logging control

    Returns
    -------
    dP : np.ndarray
        Displacement field with shape (3, Z, Y, X) in float16.
    cellprob : np.ndarray
        Cell probability logits with shape (Z, Y, X) in float16.

    Raises
    ------
    ValueError
        If the source image is neither 2D nor 3D.

    Notes
    -----
    - GPU memory is explicitly freed after inference
    - Arrays are returned in float16 for memory efficiency
    - This function does not modify any files or zarr stores

    See Also
    --------
    postprocess_flows : For postprocessing raw flows
    calculate_flows : For complete flow calculation including postprocessing
    """
    image: np.ndarray = config.src_zarr[:]
    pretrained_model = "cpsam" if config.model == "cpsam" else config.model_dir
    cellpose_model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model)
    time_start = datetime.now()
    vprint(f"Starting segmentation at {time_start}", config.verbose)
    # Shape and channel settings to handle 2D and 3D images correctly
    settings_ndim = {
        2: {"z_axis": None, "channel_axis": None, "do_3D": False},
        3: {
            "z_axis": 0,
            "channel_axis": None,
            "do_3D": True,
        },
    }
    # Fail with a clear message instead of a bare KeyError on, e.g., 4D input.
    if image.ndim not in settings_ndim:
        raise ValueError(f"Expected a 2D or 3D image, got {image.ndim} dimensions")
    with torch.inference_mode():
        _, flows, _ = cellpose_model.eval(
            image,
            **settings_ndim[image.ndim],
            batch_size=config.batch_size,
            tile_overlap=config.tile_overlap,
            flow3D_smooth=config.flow3D_smooth,
            augment=config.augment,
            compute_masks=False,
        )
    # Ensure GPU memory is garbage collected
    cellpose_model.net.to("cpu")
    del cellpose_model
    torch.cuda.empty_cache()
    time_finish = datetime.now()
    vprint(f"Completed segmentation at {time_finish}", config.verbose)
    vprint(f"Elapsed time: {time_finish - time_start}", config.verbose)
    dP = np.ascontiguousarray(flows[1]).astype(np.float16, copy=False)
    cellprob = np.ascontiguousarray(flows[2]).astype(np.float16, copy=False)
    # Free the flows tuple to release any GPU references
    del flows
    torch.cuda.empty_cache()
    return dP, cellprob
def postprocess_flows(
    config: CellposeFlowConfig,
    dP: np.ndarray,
    cellprob: np.ndarray,
    save_root: zarr.Group | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Postprocess raw cellpose flows and save to zarr.

    This function applies optional filtering and decomposition steps to raw flows,
    then saves both raw (if requested) and processed flows to zarr. This function
    is intended for debugging and iterative development of postprocessing pipelines.

    Parameters
    ----------
    config : CellposeFlowConfig
        Configuration containing postprocessing parameters and save paths.
        Relevant fields:
        - out_path: Destination zarr store path
        - seg_target: Target label group name
        - spacing_dir: Spacing directory name
        - save_raw_flows: Whether to save unprocessed flows
        - median_filter_cellprob: Whether to apply median filter
        - median_filter_size: Median filter size
        - guided_filter_cellprob: Whether to apply guided filter
        - guided_filter_radius, guided_filter_eps: Guided filter parameters
        - decompose_flows: Whether to decompose flows via Helmholtz-Hodge
        - decompose_flows_pad_fraction: Padding for flow decomposition
        - zarr_chunks: Chunk size for zarr arrays
        - verbose: Logging control
    dP : np.ndarray
        Raw displacement field with shape (3, Z, Y, X), or (2, Y, X) if 2D.
    cellprob : np.ndarray
        Raw cell probability logits with shape (Z, Y, X), or (Y, X) if 2D.
    save_root : zarr.Group | None, optional
        Pre-opened zarr group for saving. If None, opens config.out_path.
        This allows external control over zarr cleanup operations.

    Returns
    -------
    dP_processed : np.ndarray
        Processed displacement field (after optional decomposition).
    cellprob_processed : np.ndarray
        Processed cell probability (after optional filtering).

    Notes
    -----
    - Does NOT perform zarr group cleanup (caller's responsibility)
    - Suitable for iterative debugging of postprocessing parameters
    - GPU operations (decompose_flow) clean up their own memory
    - All arrays are saved as float16 for storage efficiency

    See Also
    --------
    compute_raw_flows : For raw flow computation
    calculate_flows : For complete flow calculation including zarr cleanup
    """
    # Input validation: dP must carry one displacement channel per spatial axis.
    if dP.shape[0] != cellprob.ndim:
        raise ValueError(
            f"dP must have shape ({cellprob.ndim}, Z, Y, X), got {dP.shape}"
        )
    if dP.shape[1:] != cellprob.shape:
        raise ValueError(
            f"dP and cellprob spatial dimensions must match. "
            f"Got dP shape {dP.shape} and cellprob shape {cellprob.shape}"
        )
    vprint("Starting flow postprocessing", config.verbose)
    # Open zarr if not provided
    if save_root is None:
        save_root = zarr.open_group(config.out_path, mode="a")
    target_group = f"labels/{config.seg_target}"
    processing = ProcessingStep.from_config("segmentation", config)
    if config.save_raw_flows:
        vprint("Saving raw flows", config.verbose)
        write_zarr(
            save_root,
            cellprob,
            f"{target_group}/flows/cellprob-raw/{config.spacing_dir}",
            src_zarr=config.src_zarr,
            processing=processing,
            zarr_chunks=config.zarr_chunks,
            dtype="f2",
        )
        write_zarr(
            save_root,
            dP,
            f"{target_group}/flows/dP-raw/{config.spacing_dir}",
            src_zarr=config.src_zarr,
            processing=processing,
            # Channel count follows dP (2 for 2D, 3 for 3D) instead of a
            # hardcoded 3, so 2D flows get a correct chunk shape.
            zarr_chunks=(dP.shape[0], *config.zarr_chunks),
            dtype="f2",
        )
    if config.median_filter_cellprob:
        cellprob = median_filter(cellprob, config.median_filter_size)
    # Guided filter should be done using the raw dP.
    if config.guided_filter_cellprob:
        dP_mag = np.sqrt(np.sum(dP**2, axis=0))
        cellprob = guided_filter(
            cellprob,
            guide=dP_mag / dP_mag.max(),
            radius=config.guided_filter_radius,
            eps=config.guided_filter_eps,
        )
    if config.decompose_flows:
        dP = decompose_flow(
            dP, config.decompose_flows_pad_fraction, torch.device("cuda")
        )
    # Saving processed flows
    vprint("Saving processed flows", config.verbose)
    write_zarr(
        save_root,
        cellprob,
        f"{target_group}/flows/cellprob/{config.spacing_dir}",
        src_zarr=config.src_zarr,
        processing=processing,
        zarr_chunks=config.zarr_chunks,
        dtype="f2",
    )
    write_zarr(
        save_root,
        dP,
        f"{target_group}/flows/dP/{config.spacing_dir}",
        src_zarr=config.src_zarr,
        processing=processing,
        zarr_chunks=(dP.shape[0], *config.zarr_chunks),
        dtype="f2",
    )
    return dP, cellprob
def calculate_flows(config: CellposeFlowConfig) -> None:
    """
    Segment volume stack using cellpose: compute flows and postprocess.

    This is the main entry point for cellpose flow calculation. It performs model
    inference, postprocessing, and saving in a single call. For debugging or
    iterative development, use compute_raw_flows() and postprocess_flows() separately.

    Parameters
    ----------
    config : CellposeFlowConfig
        Complete configuration for flow calculation and postprocessing.

    Notes
    -----
    - Deletes existing labels/{seg_target} group to ensure clean state
    - Calls compute_raw_flows() followed by postprocess_flows()
    - Maintains backward compatibility with existing scripts

    See Also
    --------
    compute_raw_flows : For inference-only workflow
    postprocess_flows : For postprocessing pre-computed flows
    calculate_masks : For generating masks from processed flows
    """
    # Step 1: Compute raw flows
    dP, cellprob = compute_raw_flows(config)
    # Step 2: Prepare zarr (cleanup existing group)
    save_root = zarr.open_group(config.out_path, mode="a")
    target_group = f"labels/{config.seg_target}"
    if save_root.get(target_group) is not None:
        # Idiomatic deletion instead of calling __delitem__ directly.
        del save_root[target_group]
    # Step 3: Postprocess and save
    postprocess_flows(config, dP, cellprob, save_root=save_root)
@dataclass
class CellposeMaskConfig(BaseConfig):
    """Configuration for generating segmentation masks from Cellpose flows.

    Parameters
    ----------
    root_path : Path
        Root directory of the Zarr archive.
    seg_target : str
        Segmentation target name (e.g. ``"cells"`` or ``"nuclei"``).
    spacing_dir : str, optional
        Spacing directory that identifies the flow arrays. Default is
        ``"100-100-100"``.
    label_root : str | None, optional
        Optional prefix group path for non-standard label locations within
        the Zarr archive. Default is None.
    niter : int, optional
        Number of Euler integration steps in the flow dynamics solver.
        Default is 200.
    cellprob_threshold : float, optional
        Cellprob logit threshold; voxels below this value are excluded.
        Default is -0.5.
    flow_threshold : float, optional
        Maximum allowed flow error for retaining a mask. Default is 0.4.
    min_diam : float, optional
        Minimum object diameter in micrometers used to derive ``min_size``
        in voxels. Default is 3.
    expand_labels : bool, optional
        Expand labels into the foreground mask after mask generation.
        Default is False.
    max_expansion_steps : int, optional
        Maximum dilation iterations for label expansion. Default is 5.
    merge_masks : bool, optional
        Merge adjacent under-segmented labels via RAG-based merging.
        Default is True.
    gaussian_edge_sigma : float, optional
        Gaussian sigma for edge map computation during merging. Default 2.0.
    merge_weight_threshold : float, optional
        Maximum edge weight for a merge to be accepted. Default is 0.2.
    merge_contact_threshold : float, optional
        Minimum relative contact area for a merge to be accepted. Default 0.2.
    device : str, optional
        Torch device string for mask generation. Default is ``"cuda"``.
    zarr_chunks : tuple[int] | None, optional
        Chunk shape for the output mask array. If None, inherits source chunks.

    Raises
    ------
    ValueError
        If the source spacing is neither 2D nor 3D.
    """

    root_path: Path
    seg_target: str
    spacing_dir: str = "100-100-100"
    label_root: str | None = None
    niter: int = 200
    cellprob_threshold: float = -0.5
    flow_threshold: float = 0.4
    min_diam: float = 3
    expand_labels: bool = False
    max_expansion_steps: int = 5
    merge_masks: bool = True
    gaussian_edge_sigma: float = 2.0
    merge_weight_threshold: float = 0.2
    merge_contact_threshold: float = 0.2
    device: str = "cuda"
    zarr_chunks: tuple[int] | None = None
    # Derived in __post_init__; excluded from the init signature.
    min_size: int = field(init=False)
    spacing: list[int | float] = field(init=False)
    label_path: str = field(init=False)
    # Fields that should not be recorded as processing parameters.
    EXCLUDED_PROCESSING_FIELDS = {
        "root_path",
        "device",
        "zarr_chunks",
        "label_root",
        "label_path",
    }

    def __post_init__(self):
        """Derive ``spacing``, ``min_size``, ``zarr_chunks``, and ``label_path``."""
        # Calculate min_size in pixels from min_diam in micrometers
        src_zarr = zarr.open_array(
            self.root_path / f"images/{self.spacing_dir}", mode="r"
        )
        self.spacing = src_zarr.attrs.get("spacing")
        # Determine whether min_size should be an area (2D) or a volume (3D).
        # NOTE(review): the 1e-6 / 1e-9 factors imply spacing is stored in
        # nanometers (nm^2 -> um^2, nm^3 -> um^3) — confirm against writers.
        if len(self.spacing) == 2:
            pixel_um = np.prod(self.spacing) * 1e-6
            min_area_um = np.pi * (self.min_diam / 2) ** 2
            self.min_size = int(min_area_um / pixel_um)
        elif len(self.spacing) == 3:
            voxel_um = np.prod(self.spacing) * 1e-9
            min_vol_um = 4 / 3 * np.pi * (self.min_diam / 2) ** 3
            self.min_size = int(min_vol_um / voxel_um)
        else:
            # Previously min_size was silently left unset, causing a late
            # AttributeError in calculate_masks; fail here instead.
            raise ValueError(f"Expected 2D or 3D spacing, got {self.spacing!r}")
        if not self.zarr_chunks:
            self.zarr_chunks = src_zarr.chunks
        # Process label_root to cover cases where labels are not in a standard location
        # within the zarr store.
        # One example of this is pretrained labels, which are saved under
        # dataset.zarr/pretrained/
        self.label_path = (
            f"{self.label_root}/labels/{self.seg_target}"
            if self.label_root is not None
            else f"labels/{self.seg_target}"
        )
def calculate_masks(config: CellposeMaskConfig) -> None:
    """Generate segmentation masks from pre-computed Cellpose flows.

    Loads cellprob and dP arrays from the Zarr archive, runs the Cellpose
    dynamics solver, applies optional label expansion and RAG-based merging,
    and writes the final mask array back to the archive.

    Parameters
    ----------
    config : CellposeMaskConfig
        Configuration specifying flow sources, solver parameters, and
        post-processing options.

    Raises
    ------
    KeyError
        If the expected flow arrays are missing from the Zarr archive.
    """
    device = torch.device(config.device)
    root = zarr.open_group(config.root_path, mode="a")
    cellprob_zarr = root.get(f"{config.label_path}/flows/cellprob/{config.spacing_dir}")
    dp_zarr = root.get(f"{config.label_path}/flows/dP/{config.spacing_dir}")
    # Fail with a clear message instead of TypeError from slicing None.
    if cellprob_zarr is None or dp_zarr is None:
        raise KeyError(
            f"Missing flow arrays under '{config.label_path}/flows' for "
            f"spacing '{config.spacing_dir}'; run flow calculation first"
        )
    cellprob: np.ndarray = cellprob_zarr[:]
    dP: np.ndarray = dp_zarr[:]
    do_3d = cellprob.ndim == 3
    masks = dynamics.compute_masks(
        dP=dP,
        cellprob=cellprob,
        niter=config.niter,
        cellprob_threshold=config.cellprob_threshold,
        flow_threshold=config.flow_threshold,
        do_3D=do_3d,
        min_size=config.min_size,
        device=device,
    )
    # Post-process labels; track provenance inputs for the output array.
    inputs = [cellprob_zarr.path, dp_zarr.path]
    if config.expand_labels:
        masks = expand_labels(
            masks,
            cellprob_logits=cellprob,
            cellprob_threshold=0,
            max_expansion_steps=config.max_expansion_steps,
        )
    if config.merge_masks:
        image_arr = root.get(f"images/{config.spacing_dir}")
        inputs.append(image_arr.path)
        image: np.ndarray = image_arr[:]
        masks, _ = merge_labels(
            masks,
            cellprob=cellprob,
            image=image,
            rel_contact_thresh=config.merge_contact_threshold,
            weight_thresh=config.merge_weight_threshold,
            sigma=config.gaussian_edge_sigma,
        )
    processing = ProcessingStep.from_config("segmentation-mask-generation", config)
    # Smallest unsigned dtype that holds all label IDs; previously label IDs
    # above 65535 would silently wrap when cast to uint16.
    max_label = int(masks.max())
    if max_label <= np.iinfo(np.uint8).max:
        mask_dtype = np.uint8
    elif max_label <= np.iinfo(np.uint16).max:
        mask_dtype = np.uint16
    else:
        mask_dtype = np.uint32
    write_zarr(
        root,
        masks,
        f"{config.label_path}/masks/{config.spacing_dir}",
        src_zarr=cellprob_zarr,
        dtype=mask_dtype,
        inputs=inputs,
        processing=processing,
    )