"""
Nanoparticle segmentation
"""
import json
from typing import Self
from pathlib import Path
from dataclasses import dataclass, field
from tqdm import tqdm
import numpy as np
import zarr
from scipy.optimize import minimize_scalar
from scipy.ndimage import find_objects
import skimage as ski_cpu
from sphero_vem.utils import (
BaseConfig,
ProcessingStep,
CustomJSONEncoder,
dirname_from_spacing,
timestamp,
)
from sphero_vem.utils.accelerator import (
xp,
gpu_dispatch,
ArrayLike,
ndi,
to_host,
to_device,
)
from sphero_vem.io import write_zarr, _create_zarr_array, _write_zarr_metadata
from sphero_vem.postprocessing import binary_closing, filter_and_relabel
from sphero_vem.segmentation.np.utils import bincount_ubyte
@dataclass
class NanoparticleSegConfig(BaseConfig):
"""Configuration for nanoparticle (NP) segmentation via empirical Bayes EM.
Parameters
----------
root_path : Path
Path to the root zarr store containing the image stack.
spacing_dir : str
Array path within ``images/`` (e.g. ``"50-10-10"``).
verbose : bool, optional
Enable progress bars and diagnostic output. Default is False.
max_iter : int, optional
Maximum number of EM iterations. Default is 10000.
eps : float, optional
Numerical floor added to PMFs to prevent log(0). Default is 1e-12.
nll_tol : float, optional
Convergence tolerance: stops when the per-iteration decrease in NLL
falls below this value. Default is 1e-10.
sampling_step : int, optional
Slice stride used when sampling the stack for histogram estimation;
1 uses every slice. Default is 1.
percent_th_low : float, optional
Lower intensity threshold as a CDF percentile of the stack histogram.
Pixels above this percentile are treated as NP candidates.
Default is 99.5.
percent_th_high : float, optional
Upper intensity threshold as a CDF percentile; pixels above this are
always excluded from the background histogram. Default is 99.8.
halo_pad : int, optional
Padding in pixels added around detected NP seeds when building the
background histogram. Default is 20.
min_size : int, optional
Minimum connected-component size in pixels for NP seed detection
during background extraction. Default is 20.
posterior_th : float, optional
Posterior probability threshold for binarization, used externally
by ``label_nanoparticles``. Default is 0.95.
beta_params : tuple[float, float], optional
Shape parameters (alpha, beta) of the Beta prior on the per-image
NP mixing weight ``pi``. Default is (1.0, 20.0).
zarr_chunks : tuple[int, int, int], optional
Chunk shape (Z, Y, X) for the posterior zarr array.
Default is (1, 1024, 1024).
model_name : str | None
Name of the fitted model. If None, the model name will be assigned
automatically as ``f"nps-{timestamp()}"``.
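
    Examples
    --------
    A minimal construction sketch; the store path is hypothetical:

    >>> config = NanoparticleSegConfig(
    ...     root_path=Path("/data/stack.zarr"),  # hypothetical store
    ...     spacing_dir="50-10-10",
    ... )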
"""
root_path: Path
    spacing_dir: str
verbose: bool = False
max_iter: int = 10000
eps: float = 1e-12
nll_tol: float = 1e-10
sampling_step: int = 1
percent_th_low: float = 99.5
percent_th_high: float = 99.8
halo_pad: int = 20
min_size: int = 20
posterior_th: float = 0.95
beta_params: tuple[float, float] = (1.0, 20.0)
zarr_chunks: tuple[int, int, int] = (1, 1024, 1024)
model_name: str | None = None
seg_target: str = field(init=False)
    EXCLUDED_PROCESSING_FIELDS = {
        "root_path",
        "spacing_dir",
        "verbose",
        "zarr_chunks",
    }
def __post_init__(self) -> None:
super().__post_init__()
self.seg_target = "nps"
if self.model_name is None:
self.model_name = f"nps-{timestamp()}"
class NanoparticleSegmentation:
"""Two-component intensity mixture model for nanoparticle segmentation.
Decomposes the voxel intensity histogram of an SBF-SEM stack into a
background distribution and a nanoparticle (NP) distribution using an
Expectation-Maximization algorithm. Per-slice posterior probabilities
are estimated by MAP optimization of a per-image mixing weight.
Parameters
----------
config : NanoparticleSegConfig
Segmentation configuration.
Attributes
----------
config : NanoparticleSegConfig
Configuration passed at construction.
stack_root : zarr.Group
        Open zarr group at ``config.root_path``.
volume_stack : zarr.Array
Image array at ``images/{config.spacing_dir}`` within the group.
p_stack : numpy.ndarray
Full-stack intensity PMF of shape (256,). Set by :meth:`fit`.
p_bg : numpy.ndarray
Background intensity PMF of shape (256,). Set by :meth:`fit`.
p_np : numpy.ndarray
Fitted NP intensity PMF of shape (256,). Set by :meth:`fit`.
hist_stack : numpy.ndarray
Raw intensity histogram of the stack, shape (256,). Set by :meth:`fit`.
hist_bg : numpy.ndarray
Raw background pixel histogram, shape (256,). Set by :meth:`fit`.
summary_fit : dict
EM convergence summary with keys ``w_bg`` (final background mixing
weight), ``nll`` (final negative log-likelihood), and ``nit``
(number of iterations). Set by :meth:`fit`.
Notes
-----
Call :meth:`fit` before :meth:`predict`. A fitted model can be
persisted with :meth:`save` and restored from disk with :meth:`load`.
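
    Examples
    --------
    End-to-end sketch; the store path and model directory are hypothetical:

    >>> config = NanoparticleSegConfig(
    ...     root_path=Path("/data/stack.zarr"), spacing_dir="50-10-10"
    ... )
    >>> seg = NanoparticleSegmentation(config)
    >>> seg.fit()                # estimate p_stack, p_bg, p_np via EM
    >>> seg.save("models/nps")   # persist PMFs and manifest
    >>> seg.predict()            # write per-voxel posterior to the store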
"""
def __init__(
self,
config: NanoparticleSegConfig,
) -> None:
"""Open the zarr store and locate the image array.
Parameters
----------
config : NanoparticleSegConfig
Segmentation configuration.
"""
self.config = config
self.stack_root = zarr.open_group(self.config.root_path, mode="a")
self.volume_stack: zarr.Array = self.stack_root.get(
f"images/{self.config.spacing_dir}"
)
@classmethod
def load(cls, model_dir: str | Path) -> Self:
"""Load a pretrained model from a directory. The directory is expected to have
a training_manifest.json and model_params.npz of the correct format"""
if isinstance(model_dir, str):
model_dir = Path(model_dir)
config = NanoparticleSegConfig.from_json(model_dir / "training_manifest.json")
segmentation = cls(config)
params = np.load(model_dir / "model_params.npz")
segmentation.p_np = params["p_np"]
segmentation.p_bg = params["p_bg"]
segmentation.p_stack = params["p_stack"]
segmentation.hist_stack = params["hist_stack"]
segmentation.hist_bg = params["hist_bg"]
return segmentation
def save(self, target_dir: str | Path) -> None:
"""Save the calculated probabilities and the training parameters in the
specified directory as model_params.npz and training_manifest.json, respectively.
It also saves the summary of the fit results in fit_results.json"""
if isinstance(target_dir, str):
target_dir = Path(target_dir)
self.config.to_json(target_dir / "training_manifest.json")
np.savez(
target_dir / "model_params.npz",
p_np=self.p_np,
p_bg=self.p_bg,
p_stack=self.p_stack,
hist_stack=self.hist_stack,
hist_bg=self.hist_bg,
)
with open(target_dir / "fit_results.json", "w") as file:
json.dump(self.summary_fit, file, indent=4, cls=CustomJSONEncoder)
def _normalize_pmf(self, hist: np.ndarray) -> np.ndarray:
"""Normalize a count histogram to a probability mass function.
Converts raw counts to a PMF and adds a small floor (``config.eps``)
to every bin to avoid division by zero in log-likelihood calculations.
Parameters
----------
hist : numpy.ndarray
Integer count histogram of length 256.
Returns
-------
numpy.ndarray
Normalized PMF of length 256, dtype float64.
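
        Examples
        --------
        Counts are scaled to sum to one, with every bin floored at
        ``config.eps`` (``seg`` is any constructed instance):

        >>> hist = np.zeros(256, dtype=np.int64)
        >>> hist[100], hist[200] = 3, 1
        >>> pmf = seg._normalize_pmf(hist)
        >>> float(round(pmf[100], 6))
        0.75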
"""
pmf = hist.astype(np.float64)
if pmf.sum() == 0:
return np.full(256, 1 / 256, dtype=np.float64)
pmf /= pmf.sum()
# Add floor and re-normalize
pmf = np.maximum(pmf, self.config.eps)
pmf /= pmf.sum()
return pmf
def _calc_stack_hist(self) -> None:
"""Compute and store the intensity histogram of the full image stack.
Iterates over slices at intervals of ``config.sampling_step``, accumulates
        a 256-bin histogram of uint8 intensities, normalizes it to a PMF, and stores the results
in ``self.hist_stack`` and ``self.p_stack``.
"""
hist_stack = np.zeros(256, dtype=np.int64)
for idx in tqdm(
range(0, self.volume_stack.shape[0], self.config.sampling_step),
"Calculating stack histogram",
disable=not self.config.verbose,
):
image = self.volume_stack[idx]
hist_stack += bincount_ubyte(image)
self.hist_stack = hist_stack
self.p_stack = self._normalize_pmf(hist_stack)
def _calc_bg_hist(self) -> None:
"""Compute and store the intensity histogram of background pixels.
Determines intensity thresholds from the stack PMF, then accumulates
background pixel counts across sampled slices using ``_extract_bg_hist``.
Stores results in ``self.hist_bg`` and ``self.p_bg``.
"""
hist_bg = np.zeros(256, dtype=np.int64)
self.th_low = np.searchsorted(
self.p_stack.cumsum(), self.config.percent_th_low / 100
)
self.th_high = np.searchsorted(
self.p_stack.cumsum(), self.config.percent_th_high / 100
)
for idx in tqdm(
range(0, self.volume_stack.shape[0], self.config.sampling_step),
desc="Calculating background histogram",
disable=not self.config.verbose,
):
image = self.volume_stack[idx]
hist_bg += self._extract_bg_hist(image)
self.hist_bg = hist_bg
self.p_bg = self._normalize_pmf(hist_bg)
@gpu_dispatch(return_to_host=True)
def _extract_bg_hist(self, image: ArrayLike) -> ArrayLike | None:
"""Extract the background pixel histogram from a single 2D slice.
Identifies candidate NP regions by intensity thresholding, expands their
bounding boxes by a halo, and counts intensity values only in pixels
outside those regions (background). Returns None if no background pixels
are found.
Parameters
----------
image : ArrayLike
Grayscale 2D image slice (uint8).
Returns
-------
ArrayLike | None
256-bin integer histogram of background pixels, or None if no
background was found in this slice.
"""
        def find_objects_cpu(labels: ArrayLike) -> list[tuple[slice, ...]]:
            """``scipy.ndimage.find_objects`` is not implemented in CuPy;
            fall back to the CPU, moving the input to the host if needed."""
            labels = to_host(labels)
            return find_objects(labels)
# Rough object seeds by intensity
mask = image > self.th_low
        labels, num = ndi.label(mask, structure=xp.ones((3, 3), dtype=xp.uint8))
if num > 0:
sizes = xp.bincount(labels.ravel())[1:]
keep = xp.nonzero(sizes >= self.config.min_size)[0]
bboxes = find_objects_cpu(labels)
obj_mask = xp.zeros(image.shape, dtype=bool)
for i in keep:
sl = self._expand_bbox(bboxes[int(i)], image.shape)
obj_mask[sl] = True
else:
obj_mask = xp.zeros(image.shape, dtype=bool)
# Hard background exclusion above upper intensity threshold
bright_mask = image >= self.th_high
bg_mask = (~obj_mask) & (~bright_mask)
        if bg_mask.any():
            return bincount_ubyte(image[bg_mask])
        return None
    def _expand_bbox(
        self, bbox: tuple[slice, slice], image_shape: tuple[int, int]
    ) -> tuple[slice, slice]:
"""Expand a ``scipy.ndimage.find_objects`` bounding box by a halo margin.
Parameters
----------
        bbox : tuple[slice, slice]
            Bounding box as returned by ``find_objects``: a tuple of two
            slices (Y, X).
        image_shape : tuple[int, int]
            Shape of the 2D image, used to clip the expanded bbox to valid
            bounds.
Returns
-------
        tuple[slice, slice]
Expanded bounding box clipped to image boundaries.
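
        Examples
        --------
        With the default ``halo_pad=20``, the box grows by 20 px per side
        and is clipped to the image bounds:

        >>> seg._expand_bbox((slice(30, 40), slice(5, 15)), (100, 100))
        (slice(10, 60, None), slice(0, 35, None))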
"""
(slice_y, slice_x) = bbox
y0 = max(slice_y.start - self.config.halo_pad, 0)
y1 = min(slice_y.stop + self.config.halo_pad, image_shape[0])
x0 = max(slice_x.start - self.config.halo_pad, 0)
x1 = min(slice_x.stop + self.config.halo_pad, image_shape[1])
return (slice(y0, y1), slice(x0, x1))
def _init_w_bg(self) -> np.float64:
"""Initialize the background mixing weight for the EM algorithm.
Returns
-------
numpy.float64
Initial ``w_bg`` estimate: ratio of background pixel count to total
pixel count across the sampled stack.
"""
return self.hist_bg.sum() / self.hist_stack.sum()
def _init_p_np(self) -> np.ndarray:
"""Initialize the NP probability distribution for the EM algorithm.
Sets bins below the low-intensity threshold to zero and normalizes the
tail of the stack distribution to serve as the initial NP PMF.
Returns
-------
numpy.ndarray
Initial NP PMF of length 256, concentrated on high-intensity bins.
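
        Notes
        -----
        Equivalent to renormalizing the high-intensity tail of the stack
        PMF: ``p_np0(v)`` is proportional to ``p_stack(v)`` for
        ``v >= th_low`` and zero below.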
"""
p_np = np.zeros_like(self.p_stack, dtype=np.float64)
p_np[self.th_low :] = self.p_stack[self.th_low :]
return self._normalize_pmf(p_np)
def _deconvolve_mixture(self):
"""Fit the two-component intensity mixture model via EM.
Runs a plain EM algorithm that alternates between computing posterior
responsibilities (E-step) and updating the NP PMF and background weight
by minimizing the negative log-likelihood (M-step). Stores the fitted
NP PMF in ``self.p_np`` and convergence summary in ``self.summary_fit``.
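
        Notes
        -----
        With ``gamma(v)`` the posterior probability that intensity bin ``v``
        is background, each iteration computes::

            gamma(v) = w_bg * p_bg(v) / (w_bg * p_bg(v) + (1 - w_bg) * p_np(v))
            w_bg    <- sum_v p_stack(v) * gamma(v)
            p_np(v) <- (1 - gamma(v)) * p_stack(v), renormalized

        ``p_bg`` is held fixed; only ``p_np`` and ``w_bg`` are re-estimated.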
"""
# Initialization
w_bg = self._init_w_bg()
p_np = self._init_p_np()
def neg_log_likelihood(w_bg: float, p_np: np.ndarray) -> np.float64:
"""Calculate observed negative data log likelihood"""
mix = w_bg * self.p_bg + (1 - w_bg) * p_np
return -np.sum(self.p_stack * np.log(mix + self.config.eps))
nll_prev = np.inf
for i in range(self.config.max_iter):
# E-step
num_bg = w_bg * self.p_bg
num_np = (1 - w_bg) * p_np
gamma_bg = num_bg / (num_bg + num_np + self.config.eps)
# M-step
q = (1.0 - gamma_bg) * self.p_stack
q_sum = q.sum()
            if q_sum > self.config.eps:
p_np = self._normalize_pmf(q)
w_bg = np.sum(self.p_stack * gamma_bg)
w_bg = np.clip(w_bg, self.config.eps, 1 - self.config.eps)
else:
raise RuntimeError(
f"EM collapsed at iteration {i}: foreground responsibility "
f"vanished (q_sum={q_sum:.2e}). Check initialization or input data."
)
nll = neg_log_likelihood(w_bg, p_np)
            if nll_prev - nll < self.config.nll_tol:
break
nll_prev = nll
self.p_np = p_np
self.summary_fit = {"w_bg": w_bg, "nll": nll, "nit": i}
def fit(self) -> None:
"""Fit the two-component mixture model to the image stack.
Computes the full-stack intensity histogram (``p_stack``), extracts
the background histogram (``p_bg``), and runs EM to estimate the NP
intensity distribution (``p_np``). Populates ``hist_stack``,
``hist_bg``, ``p_stack``, ``p_bg``, ``p_np``, and ``summary_fit``.
Must be called before :meth:`predict`.
"""
self._calc_stack_hist()
self._calc_bg_hist()
self._deconvolve_mixture()
def _fit_pi(self, hist_image):
"""Estimate the per-image NP mixing weight by MAP optimization.
Minimizes the negative log-likelihood of the observed pixel histogram
under the fitted mixture model, regularized by a Beta prior on ``pi``.
Parameters
----------
hist_image : numpy.ndarray
256-bin integer histogram of the image to be analyzed.
Returns
-------
float
MAP estimate of the NP mixing weight ``pi`` in [0, 1].
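
        Notes
        -----
        The quantity minimized over ``pi`` is the negative of::

            sum_v h(v) * log(pi * p_np(v) + (1 - pi) * p_bg(v))
                + (alpha - 1) * log(pi) + (beta - 1) * log(1 - pi)

        where ``h`` is ``hist_image`` and ``(alpha, beta)`` are
        ``config.beta_params``.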
"""
def nll_beta_prior(pi):
mix = pi * self.p_np + (1 - pi) * self.p_bg
ll = np.sum(hist_image * np.log(mix + self.config.eps))
beta_prior = (
(self.config.beta_params[0] - 1) * np.log(pi + self.config.eps)
) + (self.config.beta_params[1] - 1) * np.log(1 - pi + self.config.eps)
ll += beta_prior
return -ll
res = minimize_scalar(nll_beta_prior, bounds=(0.0, 1.0), method="bounded")
return float(res.x)
@gpu_dispatch(return_to_host=True)
def _posterior_image(self, image: ArrayLike) -> tuple[np.ndarray, np.ndarray]:
"""Predict the NP posterior distribution of the given 2D image.
Parameters
----------
        image : ArrayLike
            Grayscale 2D image (uint8) to be analyzed.

        Returns
        -------
        numpy.ndarray
            Per-pixel posterior NP probability map, same shape as ``image``,
            with values in [0, 1].
        numpy.ndarray
            Per-bin posterior NP probabilities of shape (256,), values in
            [0, 1].
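
        Notes
        -----
        Per intensity bin ``v``, Bayes' rule gives::

            P(NP | v) = pi * p_np(v) / (pi * p_np(v) + (1 - pi) * p_bg(v))

        and the per-pixel map is obtained by indexing these 256 values with
        the image intensities.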
"""
hist_image = bincount_ubyte(image)
pi = self._fit_pi(hist_image)
mix = pi * self.p_np + (1 - pi) * self.p_bg
posterior_bins = to_device((pi * self.p_np) / (mix + self.config.eps))
posterior_map = posterior_bins[image]
return posterior_map, posterior_bins
def predict(self) -> None:
"""Compute and save the per-voxel NP posterior probability map.
For each slice, estimates the per-image NP mixing weight ``pi`` by MAP
optimization, then computes the per-pixel posterior probability under
the fitted mixture. Results are written to
``labels/nps/posterior/{spacing_dir}`` in the zarr store as a float16
array.
Notes
-----
Requires a fitted model; call :meth:`fit` or :meth:`load` first.
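
        Examples
        --------
        Reading the result back after a run (sketch; array layout as written
        by this method):

        >>> root = zarr.open_group(config.root_path)
        >>> posterior = root[f"labels/nps/posterior/{config.spacing_dir}"]
        >>> posterior.dtype
        dtype('float16')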
"""
dst_zarr = _create_zarr_array(
root=self.stack_root,
dst_path=f"labels/nps/posterior/{self.config.spacing_dir}",
shape=self.volume_stack.shape,
            chunks=self.config.zarr_chunks,
dtype="f2",
)
for idx in tqdm(
range(self.volume_stack.shape[0]),
desc="Calculating posterior",
disable=not self.config.verbose,
):
image = self.volume_stack[idx]
posterior, _ = self._posterior_image(image)
dst_zarr[idx] = posterior
processing = ProcessingStep.from_config("segmentation", self.config)
_write_zarr_metadata(
root=self.stack_root,
dst_zarr=dst_zarr,
src_zarr=self.volume_stack,
processing=processing,
)
def label_nanoparticles(
root_path: Path,
spacing: tuple[int, int, int],
threshold: float,
radius: int = 1,
connectivity: int = 2,
min_size: int = 10,
) -> None:
"""Threshold posterior >= threshold and label nanoparticle binary masks.
Masks are first subject to a binary closing with given radius. Then, labeling is
done with the specified connectivity and labels smaller than min_size (in voxels)
are discarded. The volume is finally relabeled to have sequential label IDs.
Parameters
----------
    root_path : Path
        Path to the root of the zarr store.
    spacing : tuple[int, int, int]
        Spacing to use for the labeling.
    threshold : float
        Threshold applied to the nanoparticle posterior, keeping voxels with
        posterior >= threshold. Must be between 0 and 1.
    radius : int
        Radius in voxels of the ball element used during binary closing.
        Default is 1.
    connectivity : int
        Connectivity used during labeling. Default is 2.
    min_size : int
        Minimum label size in voxels. Labels with size < min_size are
        discarded. Default is 10.
Notes
-----
The posterior array is loaded fully into memory. For large volumes, ensure
sufficient RAM is available before calling this function.
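
    Examples
    --------
    A minimal sketch; the store path is hypothetical:

    >>> label_nanoparticles(
    ...     root_path=Path("/data/stack.zarr"),  # hypothetical store
    ...     spacing=(50, 10, 10),
    ...     threshold=0.95,
    ... )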
"""
root = zarr.open_group(root_path)
spacing_dir = dirname_from_spacing(spacing)
posterior_zarr: zarr.Array = root.get(f"labels/nps/posterior/{spacing_dir}")
# Threshold posterior
posterior = posterior_zarr[:]
masks = posterior >= threshold
# Label masks
masks_filt = binary_closing(masks, radius=radius)
masks_filt = ski_cpu.measure.label(masks_filt, connectivity=connectivity)
masks_filt = filter_and_relabel(masks_filt, min_size=min_size)
write_zarr(
root,
masks_filt,
dst_path=f"labels/nps/masks/{spacing_dir}",
src_zarr=posterior_zarr,
processing={
"step": "labeling",
"radius": radius,
"connectivity": connectivity,
"min_size": min_size,
},
)