Source code for sphero_vem.metrics

"""
This module contains the loss functions and metrics used throughout the library.
"""

from functools import partial
import torch
import torch.nn.functional as F
from kornia.losses import ssim_loss


def ncc_loss(img1: torch.Tensor, img2: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalized cross-correlation loss between two image tensors.

    Each image is mean-centred over dimensions 2 and 3, the per-(sample,
    channel) correlation coefficient is computed, and the loss is one minus
    the average of those coefficients.

    Parameters
    ----------
    img1 : torch.Tensor
        First input image tensor of shape (N, C, H, W).
    img2 : torch.Tensor
        Second input image tensor. Must have the same shape as *img1*.
    eps : float, optional
        Small constant added to the denominator for numerical stability.
        Default is 1e-6.

    Returns
    -------
    torch.Tensor
        Scalar NCC loss in [0, 2]; 0 for perfectly correlated images.
        A channel that is constant in either image contributes a
        correlation of 0 (i.e. a loss of 1) and a finite gradient.
    """
    spatial_dims = (2, 3)

    img1_c = img1 - img1.mean(dim=spatial_dims, keepdim=True)
    img2_c = img2 - img2.mean(dim=spatial_dims, keepdim=True)

    num = (img1_c * img2_c).mean(dim=spatial_dims)
    var_prod = (
        img1_c.pow(2).mean(dim=spatial_dims) * img2_c.pow(2).mean(dim=spatial_dims)
    )

    # The derivative of sqrt at 0 is infinite; combined with the zero upstream
    # gradient of a constant channel this gives 0/0 = NaN in backward.
    # Evaluate sqrt only on strictly positive entries (a harmless placeholder
    # of 1 elsewhere) and select 0 for the rest, so the forward value is the
    # same as sqrt(clamp_min(var_prod, 0)) while the gradient stays finite.
    positive = var_prod > 0
    safe_var_prod = torch.where(positive, var_prod, torch.ones_like(var_prod))
    den = torch.where(positive, safe_var_prod.sqrt(), torch.zeros_like(var_prod))

    return 1.0 - (num / (den + eps)).mean()
class LossDispatcher:
    """Factory that resolves a loss function by name and forwards calls to it.

    Available loss names: ``"mse"``, ``"mae"``, ``"ncc"``, ``"ssim"``.

    Parameters
    ----------
    loss_name : str
        Name of the loss function to use. Case-sensitive.

    Raises
    ------
    ValueError
        If *loss_name* is not in the registry.

    Notes
    -----
    Calls are forwarded unchanged, so any extra arguments the selected loss
    needs must be supplied by the caller.
    NOTE(review): kornia's ``ssim_loss`` appears to require arguments beyond
    the two images (e.g. ``window_size``) -- confirm against the installed
    kornia version.
    """

    # Registry mapping public loss names to the callables that implement them.
    _losses = {
        "mse": F.mse_loss,
        "mae": partial(F.l1_loss, reduction="mean"),
        "ncc": ncc_loss,
        "ssim": ssim_loss,
    }

    def __init__(self, loss_name: str) -> None:
        """Resolve *loss_name* to a callable loss function.

        Parameters
        ----------
        loss_name : str
            Name of the loss function. Must be one of the keys in ``_losses``.

        Raises
        ------
        ValueError
            If *loss_name* is not a registered loss name.
        """
        try:
            self._fun = self._losses[loss_name]
        except KeyError:
            # The KeyError is an implementation detail of the lookup; suppress
            # it so the traceback shows only the user-facing ValueError.
            raise ValueError(
                f"Invalid loss: '{loss_name}'. "
                f"Available losses are: {list(self._losses)}"
            ) from None

    def __call__(self, *args, **kwargs):
        """Compute the loss by forwarding all arguments to the resolved function.

        Parameters
        ----------
        *args
            Positional arguments forwarded to the loss function.
        **kwargs
            Keyword arguments forwarded to the loss function.

        Returns
        -------
        torch.Tensor
            Loss value returned by the underlying loss function.
        """
        return self._fun(*args, **kwargs)