"""
This module contains functions used to finetune Cellpose models
"""
import os
from pathlib import Path
from dataclasses import dataclass, field, asdict
import re
import json
import logging
import numpy as np
from sklearn.model_selection import train_test_split
import wandb
from cellpose import models, train, io
from tifffile import imread
from sphero_vem.utils import timestamp
from sphero_vem.io import write_image
from sphero_vem.utils import BaseConfig
[docs]
@dataclass
class CellposeFinetuneConfig(BaseConfig):
    """Configuration for fine-tuning a Cellpose-SAM model.

    Parameters
    ----------
    dir_labeled : Path | str
        Directory containing labeled training images and a ``manifest.json``
        with ``"spacing"`` and ``"processing"`` keys. Strings are converted to
        ``Path`` during initialization.
    learning_rate : float, optional
        Initial learning rate. Default is 5e-5.
    batch_size : int, optional
        Training batch size. Default is 8.
    n_epochs : int, optional
        Number of training epochs. Default is 100.
    test_size : float, optional
        Fraction of labeled images reserved for testing. Default is 0.2.
    random_state : int, optional
        Random seed for the train/test split. Default is 42.
    seg_target : str, optional
        Segmentation target: ``"cells"`` or ``"nuclei"``. Default is
        ``"cells"``.
    save_predictions : bool, optional
        Save model predictions on test images after training. Default is False.
    use_bfloat16 : bool, optional
        Use bfloat16 mixed precision during training. Default is True.

    Attributes
    ----------
    wandb_project : str
        Weights & Biases project name derived from ``seg_target``. It is
        assigned in ``__post_init__`` and is not declared as a dataclass
        field in this class.
    model_name : str
        ``"cellposeSAM-{seg_target}-{timestamp()}"``.
    dir_experiment : Path
        ``data/models/cellpose/{model_name}``.
    dir_predictions : Path
        ``data/processed/segmented/finetuning/{model_name}``.
    spacing : list
        Value of the ``"spacing"`` key in ``manifest.json`` (``None`` when the
        key is absent).

    Raises
    ------
    ValueError
        If ``seg_target`` is neither ``"cells"`` nor ``"nuclei"``.
    """

    dir_labeled: Path | str
    learning_rate: float = 5e-5
    batch_size: int = 8
    n_epochs: int = 100
    test_size: float = 0.2
    random_state: int = 42
    seg_target: str = "cells"
    save_predictions: bool = False
    use_bfloat16: bool = True
    # Parameters that are initialized by post_init
    model_name: str = field(init=False)
    dir_experiment: Path = field(init=False)
    dir_predictions: Path = field(init=False)
    spacing: list = field(init=False)

    def __post_init__(self):
        """Set ``wandb_project``, ``model_name``, output directories, and ``spacing``."""
        # Fail early on an unsupported target instead of leaving
        # ``wandb_project`` undefined and crashing later in ``wandb.init``.
        wandb_projects = {
            "cells": "cell-segmentation",
            "nuclei": "nuclei-segmentation",
        }
        if self.seg_target not in wandb_projects:
            raise ValueError(
                f"seg_target must be one of {sorted(wandb_projects)}, "
                f"got {self.seg_target!r}"
            )
        self.wandb_project = wandb_projects[self.seg_target]

        # ``dir_labeled`` may be given as a string; every consumer joins it
        # with the ``/`` operator, which requires a Path.
        self.dir_labeled = Path(self.dir_labeled)

        self.model_name = f"cellposeSAM-{self.seg_target}-{timestamp()}"
        self.dir_experiment = Path(f"data/models/cellpose/{self.model_name}")
        self.dir_predictions = Path(
            f"data/processed/segmented/finetuning/{self.model_name}"
        )

        # Load spacing from the dataset manifest
        with open(self.dir_labeled / "manifest.json", encoding="utf-8") as file:
            manifest = json.load(file)
        self.spacing = manifest.get("spacing")
def _generate_training_manifest(
    config: CellposeFinetuneConfig,
    train_files: list[Path],
    test_files: list[Path],
) -> None:
    """Write and upload the training manifest JSON.

    Saves a JSON file recording experiment metadata, train/test file lists,
    and processing history to ``config.dir_experiment``, then uploads it to
    Weights & Biases. The experiment directory is created if it does not
    exist yet.

    Parameters
    ----------
    config : CellposeFinetuneConfig
        Fine-tuning configuration.
    train_files : list[Path]
        Paths to training image files.
    test_files : list[Path]
        Paths to test image files.
    """
    # Load the processing history of the labeled dataset. ``Path(...)`` keeps
    # this working when ``dir_labeled`` is a plain string.
    dataset_manifest_path = Path(config.dir_labeled) / "manifest.json"
    with open(dataset_manifest_path, encoding="utf-8") as file:
        manifest = json.load(file)
    processing = manifest.get("processing", [])

    training_manifest = {
        "experiment_id": config.model_name,
        "timestamp": timestamp(),
        "segmentation_target": config.seg_target,
        "learning_rate": config.learning_rate,
        "batch_size": config.batch_size,
        "n_epochs": config.n_epochs,
        "processing": processing,
        "spacing": config.spacing,
        "train_files": [str(path) for path in train_files],
        "test_files": [str(path) for path in test_files],
    }

    # Save manifest locally and send to WandB. The directory is created here
    # so that the function does not depend on an earlier step having made it.
    experiment_dir = Path(config.dir_experiment)
    experiment_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = experiment_dir / "training_manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(training_manifest, f, indent=4)
    wandb.save(manifest_path)
class _CellposeLogHandler(logging.Handler):
    """Logging handler that parses Cellpose training log lines and logs to WandB.

    Intercepts the ``cellpose.train`` logger, extracts epoch/loss/LR values
    via regex, and forwards them to ``wandb.log``. Records that do not contain
    the expected metrics are ignored.
    """

    # Compiled once at class creation rather than on every log record.
    _METRICS_PATTERN = re.compile(
        r"(\d+), train_loss=([\d\.]+), test_loss=([\d\.]+), LR=([\d\.e\-\+]+)"
    )

    def __init__(self):
        """Initialize the handler with default logging settings."""
        super().__init__()

    @classmethod
    def _parse_metrics(cls, message: str) -> dict | None:
        """Extract training metrics from a log message.

        Parameters
        ----------
        message : str
            Formatted log message.

        Returns
        -------
        dict | None
            Mapping with ``"epoch"``, ``"train_loss"``, ``"test_loss"`` and
            ``"learning_rate"``, or ``None`` if the message does not match the
            metrics pattern or a captured value is not a valid number.
        """
        match = cls._METRICS_PATTERN.search(message)
        if match is None:
            return None
        epoch, train_loss, test_loss, learning_rate = match.groups()
        try:
            return {
                "epoch": int(epoch),
                "train_loss": float(train_loss),
                "test_loss": float(test_loss),
                "learning_rate": float(learning_rate),
            }
        except ValueError:
            # The character classes also accept strings such as "1.2.3" or a
            # bare "e"; treat those as lines without metrics.
            return None

    def emit(self, record):
        """Parse a log record and upload matching metrics to Weights & Biases.

        Any exception raised while parsing or uploading is routed through
        ``handleError`` so that a logging failure never propagates into the
        code that emitted the record.

        Parameters
        ----------
        record : logging.LogRecord
            Log record from the ``cellpose.train`` logger.
        """
        try:
            metrics = self._parse_metrics(record.getMessage())
            if metrics is not None:
                wandb.log(metrics)
        except Exception:
            self.handleError(record)
[docs]
class CellposeLogger:
    """Context manager for Cellpose training logging via Weights & Biases.

    Initializes a WandB run, attaches a ``_CellposeLogHandler`` to the
    ``cellpose.train`` logger, and provides cleanup utilities. The run is
    started in ``__init__``; used in a ``with`` statement, ``stop`` is called
    automatically on exit, including when the body raises.

    Parameters
    ----------
    config : CellposeFinetuneConfig
        Fine-tuning configuration used to initialize the WandB run.
    """

    def __init__(self, config: CellposeFinetuneConfig) -> None:
        """Set up WandB and attach the Cellpose log handler.

        Parameters
        ----------
        config : CellposeFinetuneConfig
            Fine-tuning configuration.
        """
        # Activate Cellpose logging
        io.logger_setup()
        self._init_wandb(config)
        # Add WandB handler to the cellpose logger
        self.wandb_handler = _CellposeLogHandler()
        self.cellpose_logger = logging.getLogger("cellpose.train")
        self.cellpose_logger.addHandler(self.wandb_handler)

    def __enter__(self) -> "CellposeLogger":
        """Return the already-initialized logger."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Stop logging; exceptions from the ``with`` body are not suppressed."""
        self.stop()

    def _init_wandb(self, config: CellposeFinetuneConfig) -> None:
        """Log in to WandB, start the run, and store/upload the configuration.

        Parameters
        ----------
        config : CellposeFinetuneConfig
            Fine-tuning configuration; its ``dir_experiment`` is created if
            missing and receives ``config.json``.
        """
        # The experiment directory is used both as the WandB run directory
        # and as the destination of config.json, so make sure it exists.
        config.dir_experiment.mkdir(parents=True, exist_ok=True)

        wandb_api_key = os.getenv("WANDB_API_KEY")
        wandb.login(key=wandb_api_key)
        wandb.init(
            project=config.wandb_project,
            name=config.model_name,
            dir=config.dir_experiment,
        )
        wandb.config.update(asdict(config))
        # Save config to dir and upload to wandb
        config_path = config.dir_experiment / "config.json"
        config.to_json(config_path)
        wandb.save(config_path)

    def stop(self) -> None:
        """Detach the handler from the Cellpose logger and finish the WandB run."""
        self.cellpose_logger.removeHandler(self.wandb_handler)
        wandb.finish()

    def save_losses(self, train_losses: list[float], test_losses: list[float]) -> None:
        """Log per-epoch train and test losses to WandB.

        Parameters
        ----------
        train_losses : list[float]
            Training loss for each epoch.
        test_losses : list[float]
            Test loss for each epoch. Non-positive values are logged as NaN
            (presumably epochs without a test evaluation -- confirm against
            the Cellpose training loop).
        """
        for epoch, (train_loss, test_loss) in enumerate(zip(train_losses, test_losses)):
            wandb.log(
                {
                    "epoch": epoch,
                    "train_loss_epoch": train_loss,
                    "test_loss_epoch": test_loss if test_loss > 0 else np.nan,
                }
            )
def _split_dataset(config: CellposeFinetuneConfig) -> tuple[list[Path], list[Path]]:
    """Split segmentation data into train and test datasets.

    Only images that also have labels are considered. Labels are expected to
    be in a 'labels' subdirectory and have the naming
    '{image_name}-{config.seg_target}'.

    Example: with config.seg_target='cells', a valid image/labels pair is:
    - image.tif
    - labels/image-cells.tif

    Images are matched per axis with the pattern ``*-{axis}_*.tif`` for the
    axes x, y and z, and each axis is split separately so that every imaging
    plane present contributes to both subsets.

    Parameters
    ----------
    config : CellposeFinetuneConfig
        Fine-tuning configuration providing ``dir_labeled``, ``test_size`` and
        ``random_state``.

    Returns
    -------
    tuple[list[Path], list[Path]]
        Training image paths and test image paths.
    """
    train_files: list[Path] = []
    test_files: list[Path] = []
    # Ensure an even split between different imaging planes, if present
    for axis in ("x", "y", "z"):
        # glob() yields files in arbitrary, filesystem-dependent order; sort
        # so that a fixed random_state always produces the same split.
        image_list = sorted(
            path
            for path in config.dir_labeled.glob(f"*-{axis}_*.tif")
            if _labels_path(config, path).exists()
        )
        if not image_list:
            continue
        train_slices, test_slices = train_test_split(
            image_list, test_size=config.test_size, random_state=config.random_state
        )
        train_files.extend(train_slices)
        test_files.extend(test_slices)
    return train_files, test_files
def _load_data(
    config: CellposeFinetuneConfig, image_files: list[Path]
) -> tuple[list[np.ndarray], list[np.ndarray]]:
    """Read images and their matching label masks into memory.

    The label file of each image is looked up in the 'labels' subdirectory
    under the name '{image_name}-{config.seg_target}'.

    Example: with config.seg_target='cells', a valid image/labels pair is:
    - image.tif
    - labels/image-cells.tif

    Parameters
    ----------
    config : CellposeFinetuneConfig
        Fine-tuning configuration used to resolve the label paths.
    image_files : list[Path]
        Paths of the image files to load.

    Returns
    -------
    tuple[list[np.ndarray], list[np.ndarray]]
        Two lists in the same order as ``image_files``:
        - the loaded images
        - the loaded label masks
    """
    # All images are read first, then all label masks.
    images = list(map(imread, image_files))
    masks = [imread(_labels_path(config, image_path)) for image_path in image_files]
    return images, masks
def _labels_path(config: CellposeFinetuneConfig, image_path: Path) -> Path:
    """Return where the label file of an image is expected to live.

    Parameters
    ----------
    config : CellposeFinetuneConfig
        Fine-tuning configuration; ``dir_labeled`` and ``seg_target`` are read.
    image_path : Path
        Path to the source image file. Only its stem is used.

    Returns
    -------
    Path
        ``config.dir_labeled/labels/{image stem}-{seg_target}.tif``.
    """
    labels_dir = config.dir_labeled / "labels"
    label_name = f"{image_path.stem}-{config.seg_target}.tif"
    return labels_dir / label_name
[docs]
def finetune_cellpose(config: CellposeFinetuneConfig):
    """
    Finetune a Cellpose model using the parameters in the configuration.

    This function handles the complete fine-tuning process for a Cellpose model:
    it starts Weights & Biases logging, splits the labeled data, writes the
    training manifest, trains the network, logs the per-epoch losses and,
    if ``config.save_predictions`` is set, writes the model predictions for
    the test images to ``config.dir_predictions``.

    Parameters
    ----------
    config : CellposeFinetuneConfig
        Configuration object containing all necessary parameters for fine-tuning.

    Raises
    ------
    ValueError
        If no labeled image is found for ``config.seg_target``.
    """
    logger = CellposeLogger(config)
    try:
        train_files, test_files = _split_dataset(config)
        if not train_files:
            raise ValueError(
                f"No labeled images found in {config.dir_labeled} "
                f"for segmentation target {config.seg_target!r}"
            )
        _generate_training_manifest(config, train_files, test_files)

        cellpose_model = models.CellposeModel(
            gpu=True, use_bfloat16=config.use_bfloat16
        )
        train_data, train_labels = _load_data(config, train_files)
        test_data, test_labels = _load_data(config, test_files)

        _, train_losses, test_losses = train.train_seg(
            net=cellpose_model.net,
            train_data=train_data,
            train_labels=train_labels,
            test_data=test_data,
            test_labels=test_labels,
            learning_rate=config.learning_rate,
            batch_size=config.batch_size,
            n_epochs=config.n_epochs,
            model_name=config.model_name,
            save_path=config.dir_experiment,
        )
        logger.save_losses(train_losses, test_losses)
    finally:
        # Always detach the log handler and close the WandB run, even when
        # data loading or training fails.
        logger.stop()

    # Save test predictions
    if config.save_predictions:
        config.dir_predictions.mkdir(parents=True, exist_ok=True)
        for image_path, image in zip(test_files, test_data):
            # The first element of the eval() result is written as the mask.
            outputs = cellpose_model.eval(image)
            write_image(
                config.dir_predictions
                / f"{image_path.stem}-{config.seg_target}.tif",
                outputs[0],
                compressed=True,
            )