Source code for sphero_vem.denoising

"""
Functions for denoising images, based on CAREamics.
"""

import warnings
from pathlib import Path
from typing import Literal, ClassVar
from dataclasses import dataclass, field

from tqdm import tqdm
import numpy as np
import zarr
import dask.array as da
from sklearn.model_selection import train_test_split

import torch
from torch import nn

from careamics import CAREamist, Configuration
from careamics.config import create_n2v_configuration, UNetConfig
from careamics.models.layers import Conv_Block

from sphero_vem.io import write_zarr, _create_zarr_array, _write_zarr_metadata
from sphero_vem.utils import (
    timestamp,
    BaseConfig,
    ProcessingStep,
    temporary_zarr,
    dirname_from_spacing,
)
from sphero_vem.utils.logging import (
    ArtifactsCallback,
    HyperparamsCallback,
    setup_wanb_env,
    suppress_logging,
)


[docs] @dataclass class DenoisingConfig(BaseConfig): """Configuration for Noise2Void training via CAREamics. Parameters ---------- root_path : Path Path to the root Zarr archive containing the source data. src_path : str Path within the Zarr archive to the source image array. num_images : int, optional Number of 2-D slices to load for training. Default is 10. val_split : float, optional Fraction of slices to use for validation. Default is 0.2. random_state : int, optional Random seed for the train/validation split. Default is 42. batch_size : int, optional Training mini-batch size. Default is 128. patch_size : int, optional Spatial size of square patches extracted from slices. Default is 64. epochs : int, optional Number of training epochs. Default is 100. unet_depth : int, optional Depth of the U-Net encoder. Default is 2. unet_num_channels_init : int, optional Number of feature channels in the first U-Net encoder layer. Default is 32. n2v2 : bool, optional If True, use N2V2 blind-spot strategy instead of standard N2V. Default is False. num_workers : int, optional Number of data-loader worker processes. Default is 16. wandb_project : str, optional Weights & Biases project name for experiment tracking. Default is ``"denoising"``. work_root : Path, optional Root directory for model checkpoints and configs. Default is ``Path("data/models/n2v")``. model_name : str | None, optional Unique name for this training run. If None, a timestamp-based name is generated automatically. Default is None. """ root_path: Path src_path: str num_images: int = 10 val_split: float = 0.2 random_state: int = 42 # N2V hyperparameters batch_size: int = 128 patch_size: int = 64 epochs: int = 100 unet_depth: int = 2 unet_num_channels_init: int = 32 n2v2: bool = False # Other params num_workers: int = 16 wandb_project: str = "denoising" work_root: Path = Path("data/models/n2v") model_name: str | None = None work_dir: Path = field(init=False) n2v_config: Configuration = field(init=False) EXCLUDED_JSON_FIELDS = set(["n2v_config"]) EXCLUDED_PROCESSING_FIELDS = set( [ "root_path", "src_path", "num_workers", "n2v_config", "work_dir", "work_root", "wandb_project", ] ) def __post_init__(self): super().__post_init__() if not self.model_name: self.model_name = f"n2v-{timestamp()}" self.work_dir = self.work_root / self.model_name self.n2v_config = create_n2v_configuration( experiment_name=self.model_name, data_type="array", axes="SYX", patch_size=[self.patch_size, self.patch_size], batch_size=self.batch_size, num_epochs=self.epochs, use_n2v2=self.n2v2, logger="wandb", model_params={ "depth": self.unet_depth, "num_channels_init": self.unet_num_channels_init, }, checkpoint_params={"save_top_k": 1}, train_dataloader_params={"num_workers": self.num_workers}, val_dataloader_params={"num_workers": self.num_workers}, )
[docs] def save_n2v_config(self, filepath: str | Path) -> None: """Saves the N2V config class to a JSON file.""" with open(filepath, "w") as file: file.write(self.n2v_config.model_dump_json(indent=4))
[docs] def train_n2v(config: DenoisingConfig) -> None: """Train a Noise2Void model using the parameters in *config*. Loads 2-D slices from a Zarr array, splits them into training and validation sets, saves config files, and runs the CAREamics training loop with Weights & Biases logging. Parameters ---------- config : DenoisingConfig Training configuration. The Zarr archive at ``config.root_path`` must be readable and the array at ``config.src_path`` must exist. """ root = zarr.open_group(config.root_path, mode="r") src_array = root.get(config.src_path) # Load training and validation arrays into memory train_data, val_data = train_test_split( src_array[: config.num_images], test_size=config.val_split, random_state=config.random_state, ) # Save config files config_path = config.work_dir / "config.json" n2v_config_path = config.work_dir / "n2v_config.json" config.work_dir.mkdir(exist_ok=True, parents=True) config.to_json(config_path) config.save_n2v_config(n2v_config_path) # Set up callbacks callback_params = HyperparamsCallback(config.processing_metadata()) callback_artifacts = ArtifactsCallback([config_path, n2v_config_path]) # Run training setup_wanb_env(config.wandb_project) torch.set_float32_matmul_precision("high") careamist = CAREamist( config.n2v_config, work_dir=config.work_dir, callbacks=[callback_params, callback_artifacts], ) careamist.train(train_source=train_data, val_source=val_data)
def _patch_decoder_blocks(unet: nn.Module, unet_cfg: UNetConfig) -> None: """Patch UnetDecoder.decoder_blocks in-place to match CAREamics 0.0.10 architecture. In older versions of Careamics (at least <=0.0.10), decoder Conv_Blocks receive concatenated skip connections as input, producing different input channel counts. This function rebuilds only the decoder_blocks ModuleList with the old channel arithmetic, leaving the bottleneck, upsampling, and final conv untouched. This is intended to allow running older N2V models on the newer versions of the library, new models should be trained using the standard Careamics API. Parameters ---------- unet : nn.Module UNet instance loaded with the current CAREamics version. unet_cfg : UNetConfig UNet configuration from the CAREamics checkpoint. """ depth = unet_cfg.depth num_channels_init = unet_cfg.num_channels_init conv_dim = unet_cfg.conv_dims use_batch_norm = unet_cfg.use_batch_norm groups = unet_cfg.in_channels if unet_cfg.independent_channels else 1 upsampling = nn.Upsample( scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear" ) decoder_blocks: list[nn.Module] = [] for n in range(depth): decoder_blocks.append(upsampling) in_channels = (num_channels_init * 2 ** (depth - n)) * groups out_channels = in_channels // 2 decoder_blocks.append( Conv_Block( conv_dim, in_channels=(in_channels + in_channels // 2 if n > 0 else in_channels), out_channels=out_channels, intermediate_channel_multiplier=2, dropout_perc=0.0, activation="ReLU", use_batch_norm=use_batch_norm, groups=groups, ) ) unet.decoder.decoder_blocks = nn.ModuleList(decoder_blocks) def _load_old_model(model_path: Path) -> CAREamist: """Load N2V trained with older versions of Careamics (<=0.0.10) Parameters ---------- model_path : Path Path to the checkpoint Returns ------- CAREamist Loaded CAREamist model. """ ckpt = torch.load(model_path, map_location="cpu") cfg = Configuration.model_validate(ckpt["hyper_parameters"]) unet_cfg = cfg.algorithm_config.model # Patch data_type for array-based prediction cfg.data_config.data_type = "array" careamist = CAREamist(cfg) _patch_decoder_blocks(unet=careamist.model.model, unet_cfg=unet_cfg) unet_state_dict = { k.removeprefix("model."): v for k, v in ckpt["state_dict"].items() if k.startswith("model.") } careamist.model.model.load_state_dict(unet_state_dict) return careamist def _load_model( model_name: str, model_root: Path = Path("data/models/n2v"), suppress_pbar: bool = True, ) -> CAREamist: """Load N2V model with the specified name. Expects exactly one best checkpoint. If multiple are present, returns the first found after sorting. Parameters ---------- model_name : str Name of the model to load. model_root : Path Root directory containing trained models. Default is Path("data/models/n2v"). suppress_pbar : bool If True, suppress progress bars during prediction. Default is True. Returns ------- CAREamist Loaded CAREamist model. """ ckpt_dir = model_root / f"{model_name}/checkpoints" model_path = sorted(ckpt_dir.glob(f"{model_name}*.ckpt"))[0] try: careamist = CAREamist(model_path) except RuntimeError: print( "Model architecture incopatible with current Careamics version.\n" "Attempting to load model with Careamics v0.0.10 UNet architecture" ) careamist = _load_old_model(model_path) # Suppress logging predictions to WandB careamist.trainer.loggers = [] careamist.trainer.logger = None if suppress_pbar: careamist.trainer.callbacks = [] return careamist
[docs] @dataclass class DenoisingStats: """Statistics accumulated during the denoising pass. Tracks global intensity range of the denoised output and the residual histogram (original - denoised) in the original intensity space, before any rescaling. Parameters ---------- global_min : float Running minimum of denoised float32 values across all slices. global_max : float Running maximum of denoised float32 values across all slices. residual_counts : np.ndarray Histogram counts of shape (511,), covering residuals in [-255, 255]. Notes ----- Residuals are computed as original (uint8) - denoised (float32), rounded and clipped to [-255, 255]. A well-behaved residual histogram should be approximately zero-mean and Gaussian. A small negative bias is expected due to N2V's blind-spot averaging. """ RESIDUAL_MIN: ClassVar[int] = -255 RESIDUAL_MAX: ClassVar[int] = 255 RESIDUAL_BINS: ClassVar[int] = 511 BIN_CENTERS: ClassVar[np.ndarray] = np.arange(-255, 256) global_min: float = float("inf") global_max: float = float("-inf") residual_counts: np.ndarray = field( default_factory=lambda: np.zeros(511, dtype=np.int64) )
[docs] def update(self, original: np.ndarray, denoised: np.ndarray) -> None: """Update statistics with a new slice. Parameters ---------- original : np.ndarray Original uint8 image. denoised : np.ndarray Denoised float32 image, before any rescaling. """ self.global_min = min(self.global_min, float(denoised.min())) self.global_max = max(self.global_max, float(denoised.max())) residual = np.clip( np.round(original.astype(np.float32) - denoised), self.RESIDUAL_MIN, self.RESIDUAL_MAX, ).astype(np.int16) self.residual_counts += np.bincount( residual.ravel() - self.RESIDUAL_MIN, minlength=self.RESIDUAL_BINS, )
[docs] def denoise_image( image: np.ndarray, careamist: CAREamist, tile_size: tuple[int, ...], tile_overlap: tuple[int, ...], batch_size: int, num_workers: int, rescale: bool = False, stats: DenoisingStats | None = None, ) -> np.ndarray: """Denoise a single YX image. Parameters ---------- image : np.ndarray 2D image to denoise (YX). careamist : CAREamist Trained CAREamist model. tile_size : tuple[int, ...] Tile size for prediction. tile_overlap : tuple[int, ...] Overlap between tiles. batch_size : int Number of tiles to process in parallel. num_workers : int Number of dataloader workers. rescale : bool If True, rescale denoised output to uint8 [0, 255] using per-slice min/max. Default is False. stats : DenoisingStats | None If provided, updated in-place with global min/max and residual histogram. Stats are accumulated from the float32 denoised image before rescaling. Default is None. Returns ------- np.ndarray Denoised image. If rescale=True, returns uint8, otherwise float32. """ with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message=".*predict_dataloader.*num_workers.*", category=UserWarning, ) with suppress_logging(): denoised = careamist.predict( source=image, tile_size=tile_size, tile_overlap=tile_overlap, axes="YX", data_type="array", batch_size=batch_size, dataloader_params={"num_workers": num_workers}, )[0].squeeze() if stats is not None: stats.update(image, denoised) if rescale: img_min = denoised.min() img_max = denoised.max() denoised = ( ((denoised - img_min) / (img_max - img_min) * 255) .clip(0, 255) .astype(np.uint8) ) return denoised
[docs] def denoise_stack( root_path: Path, src_path: str, model_name: str, dst_group: str = "images/denoised", model_root: Path = Path("data/models/n2v"), tile_size: tuple[int, int] = (512, 512), tile_overlap: tuple[int, int] = (48, 48), batch_size: int = 64, num_workers: int = 16, temp_dir: Path | str | None = Path("data/tmp"), rescale_mode: Literal["per_slice", "global"] = "per_slice", ) -> None: """Denoise volume with Noise2Void and rescale to uint8. Performs denoising with either global or per-slice intensity rescaling. Both modes accumulate a residual histogram saved as an npz file alongside the output zarr. Parameters ---------- root_path : Path Path to the zarr root archive. src_path : str Path to the source array within the zarr archive. model_name : str Name of the trained N2V model. Used to locate the model checkpoint and configuration file under `model_root`. dst_group : str Destination group within the zarr archive. The output array path is constructed as `{dst_group}/{dirname_from_spacing(spacing)}`. Default is "images/denoised". model_root : Path Root directory containing trained models. Each model should be in a subdirectory with a config.json and checkpoint file. Default is Path("data/models/n2v"). tile_size : tuple[int, int] Tile size in pixels (Y, X) for prediction. Default is (512, 512). tile_overlap : tuple[int, int] Overlap in pixels (Y, X) between adjacent tiles to avoid boundary artifacts. Default is (48, 48). batch_size : int Number of tiles to process in parallel on the GPU. Default is 64. num_workers : int Number of dataloader workers for tile loading. Default is 16. temp_dir : Path | str | None Directory for intermediate float32 zarr storage. Only used when `rescale_mode='global'`. Should be on fast local storage (SSD). Default is Path("data/tmp"). rescale_mode : Literal["per_slice", "global"] If `'global'`, use global min/max across all slices for rescaling (two-pass, requires temporary zarr storage). If `'per_slice'`, rescale each slice independently using per-slice min/max (single-pass, no temporary storage). Default is `'per_slice'`. Notes ----- When `rescale_mode='global'`, the intermediate zarr is uncompressed and can be large (4 bytes per voxel). Ensure sufficient disk space in `temp_dir`. The residual histogram is saved to `{root_path}/images/tables/denoised-residual-hist.npz`. """ torch.set_float32_matmul_precision("high") root = zarr.open_group(root_path, mode="a") src_zarr = root.get(src_path) config = None try: config = DenoisingConfig.from_json(model_root / f"{model_name}/config.json") except FileNotFoundError: pass careamist = _load_model(model_name=model_name, model_root=model_root) dst_path = f"{dst_group}/{dirname_from_spacing(src_zarr.attrs['spacing'])}" hist_path = root_path / "images/tables/denoised-residual-hist.npz" hist_path.parent.mkdir(exist_ok=True, parents=True) if rescale_mode == "global": _denoise_global_rescale( src_zarr=src_zarr, careamist=careamist, root=root, dst_path=dst_path, config=config, tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, num_workers=num_workers, temp_dir=temp_dir, hist_path=hist_path, ) elif rescale_mode == "per_slice": _denoise_per_slice_rescale( src_zarr=src_zarr, careamist=careamist, root=root, dst_path=dst_path, config=config, tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, num_workers=num_workers, hist_path=hist_path, ) else: raise ValueError( f"Invalid value {rescale_mode} for rescale_mode. " "Valid options are 'global' and 'per_slice'." )
def _save_residual_histogram(stats: DenoisingStats, hist_path: Path) -> None: """Save residual histogram counts and bin centers to an npz file. Parameters ---------- stats : DenoisingStats Accumulated denoising statistics. hist_path : Path Output path for the npz file. """ np.savez( hist_path, counts=stats.residual_counts, bin_centers=stats.BIN_CENTERS, ) def _denoise_global_rescale( src_zarr: zarr.Array, careamist: CAREamist, root: zarr.Group, dst_path: str, config: DenoisingConfig | None, tile_size: tuple[int, int], tile_overlap: tuple[int, int], batch_size: int, num_workers: int, temp_dir: Path | str | None, hist_path: Path, ) -> None: """Two-pass denoising with global rescaling. Parameters ---------- src_zarr : zarr.Array Source zarr array. careamist : CAREamist Trained CAREamist model. root : zarr.Group Root zarr group. dst_path : str Destination path within the zarr archive. config : DenoisingConfig | None Denoising configuration. If None, denoising training steps will not be written in the zarr metadata. tile_size : tuple[int, int] Tile size in pixels (Y, X). tile_overlap : tuple[int, int] Overlap in pixels (Y, X) between tiles. batch_size : int Number of tiles to process in parallel. num_workers : int Number of dataloader workers. temp_dir : Path | str | None Directory for intermediate float32 zarr storage. hist_path : Path Output path for the residual histogram npz file. """ stats = DenoisingStats() with temporary_zarr( shape=src_zarr.shape, chunks=(1, *src_zarr.shape[1:]), dtype=np.float32, dir=temp_dir, ) as intermediate: # Pass 1: Denoise, accumulate stats, store float32 for z in tqdm(range(src_zarr.shape[0]), desc="Denoising"): denoised = denoise_image( src_zarr[z], careamist, tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, num_workers=num_workers, stats=stats, ) intermediate[z] = denoised # Pass 2: Rescale globally and cast to uint8 intermediate_dask = da.from_zarr(intermediate) rescaled = ( ( (intermediate_dask - stats.global_min) / (stats.global_max - stats.global_min) * 255 ) .clip(0, 255) .astype(np.uint8) ) processing = [] if config is not None: processing.append(ProcessingStep.from_config("denoising-train", config)) processing.append( ProcessingStep.manual( "denoising-predict", { "tile_size": tile_size, "tile_overlap": tile_overlap, "global_min": stats.global_min, "global_max": stats.global_max, "rescale_mode": "global", }, ), ) write_zarr( root=root, array=rescaled, dst_path=dst_path, src_zarr=src_zarr, dtype=np.uint8, processing=processing, ) _save_residual_histogram(stats, hist_path) def _denoise_per_slice_rescale( src_zarr: zarr.Array, careamist: CAREamist, root: zarr.Group, dst_path: str, config: DenoisingConfig | None, tile_size: tuple[int, int], tile_overlap: tuple[int, int], batch_size: int, num_workers: int, hist_path: Path, ) -> None: """Single-pass denoising with per-slice rescaling. Parameters ---------- src_zarr : zarr.Array Source zarr array. careamist : CAREamist Trained CAREamist model. root : zarr.Group Root zarr group. dst_path : str Destination path within the zarr archive. config : DenoisingConfig | None Denoising configuration. If None, denoising training steps will not be written in the zarr metadata. tile_size : tuple[int, int] Tile size in pixels (Y, X). tile_overlap : tuple[int, int] Overlap in pixels (Y, X) between tiles. batch_size : int Number of tiles to process in parallel. num_workers : int Number of dataloader workers. hist_path : Path Output path for the residual histogram npz file. """ stats = DenoisingStats() processing = [] if config is not None: processing.append(ProcessingStep.from_config("denoising-train", config)) processing.append( ProcessingStep.manual( "denoising-predict", { "tile_size": tile_size, "tile_overlap": tile_overlap, "rescale_mode": "per_slice", }, ), ) dst_zarr = _create_zarr_array( root=root, dst_path=dst_path, shape=src_zarr.shape, chunks=src_zarr.chunks, dtype=np.uint8, ) for z in tqdm(range(src_zarr.shape[0]), desc="Denoising"): image = src_zarr[z] denoised = denoise_image( image, careamist, tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, num_workers=num_workers, stats=stats, ) img_min = denoised.min() img_max = denoised.max() dst_zarr[z] = ( ((denoised - img_min) / (img_max - img_min) * 255) .clip(0, 255) .astype(np.uint8) ) _write_zarr_metadata( root=root, dst_zarr=dst_zarr, src_zarr=src_zarr, processing=processing, ) _save_residual_histogram(stats, hist_path)