from typing import Any, Callable, Dict, List, Tuple
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import StatScores
from detectools import Task
from detectools.formats import Format
from detectools.metrics.functionnals import match_boxes
[docs]
class DetectMetric(Metric):
"""Base class for custom detection metric with torchmetrics engine."""
is_differentiable = None
higher_is_better = True
full_state_update: bool = False
def __init__(
self,
func: Callable,
iou_threshold: float = 0.5, # to consider as TP or FP
name: str = "DetectionMetric",
**kwargs,
):
super().__init__(**kwargs)
self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("samplewise", default=[], dist_reduce_fx="cat")
self.func = func
self.iou_threshold = iou_threshold
self.name = name
[docs]
def update(self, prediction: Format, target: Format):
"""Update internal states of metric."""
# if no objects in prediction or target: stats are neutral (0)
# else compute match box and stats
if prediction.size != 0 and target.size != 0:
tp, fp, fn, _ = match_boxes(prediction, target, self.iou_threshold)
elif prediction.size != 0 and target.size == 0:
tp, fp, fn, _ = 0, prediction.size, 0, None
elif prediction.size == 0 and target.size != 0:
tp, fp, fn, _ = 0, 0, target.size, None
else:
return
self.tp += tp
self.fp += fp
self.fn += fn
# None because of no tn in detection
self.samplewise.append(self.func(tp, fp, None, fn))
[docs]
def compute(self) -> Dict[str, Tensor]:
"""Return metric computed with internal state.
Returns:
Dict[str, Tensor]: _description_
"""
global_value = self.func(self.tp, self.fp, None, self.fn)
samplewise_value = torch.nanmean(torch.tensor(self.samplewise))
metric_dict = {"global": global_value, "samplewise": samplewise_value}
return metric_dict
[docs]
class ClassifMetric(Metric):
"""Child class of torchmetrics metrics for classification.
Allow to take Format as inputs and return dict of metric."""
def __init__(
self,
func: Callable, # metric functionnal
num_classes: int = 1,
name: str = "ClassifMetric",
**kwargs: Any,
):
super().__init__(**kwargs)
self.func = func
self.task = "binary" if num_classes == 1 else "multiclass"
self.nc = num_classes
# use tm engine to get statistics (tp,tn,fp,fn,sup)
self.stat_score = StatScores(
task=self.task,
multidim_average="samplewise",
average="none",
num_classes=num_classes,
)
self.add_state("stats", default=[], dist_reduce_fx="cat")
self.name = name
[docs]
def update(self, prediction: Format, target: Format):
"""Update internal states."""
# if no predictions or target, no classification evaluation
if prediction.size == 0 or target.size == 0:
return
target_labels = target.get("labels")
# match boxes
_, _, _, (pred_idxs, target_idxs) = match_boxes(prediction, target)
pred_idxs = pred_idxs.to(prediction.get_device())
target_idxs = target_idxs.to(target.get_device())
prediction = prediction[pred_idxs]
target = target[target_idxs]
# if no box match, add all targets in fn
if prediction.size == 0:
device = target.get("boxes").device
class_stats = torch.zeros((self.nc, 5)).to(device)
values = torch.tensor(
[torch.sum(target_labels == i) for i in range(self.nc)]
)
class_stats[:, 4] = values
self.stats.append(class_stats[None, ...])
return
# get labels
pred_labels = prediction.get("labels")
target_labels = target.get("labels")
# if binary pass label 0 to 1
if self.task == "binary":
pred_labels += 1
target_labels += 1
# compute stats
stats = self.stat_score(pred_labels[None, ...], target_labels[None, ...]).view(
1, self.nc, 5
)
self.stats.append(stats)
[docs]
def global_micro_compute(self) -> Tensor:
"""Compute metric with global/micro averagging."""
n_samples = len(self.stats)
samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5) # (N, NC, 5)
# sum stats accross samples
samples_stack = samples_stack.sum(dim=0) # (NC, 5)
# sum stats accross classes
micro_stack = torch.sum(samples_stack, dim=0) # (5,)
# compute metric
tp, fp, tn, fn, _ = micro_stack.unbind(0)
return self.func(tp, fp, tn, fn)
[docs]
def global_macro_compute(self) -> Tuple[Tensor, Tensor]:
"""Compute metric with global/macro averraging.
Return also metric/class tensor."""
n_samples = len(self.stats)
samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5) # (N, NC, 5)
# sum stats accross samples
samples_stack = samples_stack.sum(dim=0) # (NC, 5)
# compute metric/class
tp, fp, tn, fn, _ = samples_stack.unbind(1)
class_metrics = self.func(tp, fp, tn, fn) # (NC,)
return torch.nanmean(class_metrics), class_metrics
[docs]
def samplewise_micro(self) -> Tensor:
"""Compute metric with samplewise/micro averagging."""
n_samples = len(self.stats)
samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5) # (N, NC, 5))
# sum stat accross classes
samples_stack = samples_stack.sum(dim=1) # (NC, 5)
# compute metric/sample
tp, fp, tn, fn, _ = samples_stack.unbind(dim=1)
samples_metrics = self.func(tp, fp, tn, fn)
# mean accross samples
return torch.nanmean(samples_metrics)
[docs]
def samplewise_macro(self) -> Tensor:
"""Compute metric with samplewise/macro averagging."""
n_samples = len(self.stats)
samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5) # (N, NC, 5)
# compute metric/class/sample
tp, fp, tn, fn, _ = samples_stack.unbind(2)
class_metrics = self.func(tp, fp, tn, fn) # (N,NC)
# mean accross classes
macro = torch.nanmean(class_metrics, dim=1) # (N,)
# mean accross samples
macro_samplewise = torch.nanmean(macro, dim=0)
return macro_samplewise
[docs]
def compute(self):
"""Comput metric with all averag strategy and return a dict with all values."""
if not self.stats:
return {self.name: torch.tensor(torch.nan)}
# if binary no need for macro aggregation
metric_dict = {}
if self.task == "multiclass":
# global micro
global_micro = self.global_micro_compute()
metric_dict.update({"_global_micro": global_micro})
# global macro
global_macro, class_metrics = self.global_macro_compute()
classes_dict = {
f"/cls_{i}": class_metrics[i] for i in range(class_metrics.nelement())
}
metric_dict.update({"_global_macro": global_macro})
metric_dict.update(classes_dict)
# samplewise micro
samplewise_micro = self.samplewise_micro()
metric_dict.update({"_samplewise_micro": samplewise_micro})
# samplewise macro
samplewise_macro = self.samplewise_macro()
metric_dict.update({"_samplewise_macro": samplewise_macro})
else:
# global micro
global_micro = self.global_micro_compute()
metric_dict.update({"_global": global_micro})
# samplewise micro
samplewise_micro = self.samplewise_micro()
metric_dict.update({"_samplewise": samplewise_micro})
return metric_dict
return metric_dict
[docs]
class SemanticSegmentationMetric(ClassifMetric):
"""Child class of classification metric. Move from instance to semantic segmentation
paradigm to provide stats based on classes masks (instead of objects).
"""
def __init__(
self,
func: Callable, # metric functionnal
num_classes: int = 1,
name: str = "SegmentationMetric",
**kwargs: Any,
):
assert (
Task.mode == "instance_segmentation"
), "lib mode must be instance_segmentation to use SemanticSegmentationMetric, please set mode by using Task.set_lib_mode() at the begining of script "
# init from ClassifMetric and
# redefine stat score num_classes for multiclass detection to include background as a class (+1 class)
# Note that background is later removed in update (background is meaningless regarding instance segmentation) but important for stats calculations
super().__init__(func, num_classes=num_classes, name=name, **kwargs)
if num_classes > 1:
self.stat_score.num_classes = num_classes + 1
# override
[docs]
def update(self, prediction: Format, target: Format):
"""Convert target & prediction to semantic mask to compute stats in semantic segmentation paradigm. Update internal state."""
# handling empty target
if target.size == 0:
semantic_target_mask = torch.zeros(target.spatial_size).to(self.device)
else:
target_mask = target.get("masks")._mask.clone()
target_labels = target.get("labels").clone()
target_obj_idx = torch.tensor(range(1, target_labels.nelement() + 1)).to(
self.device
)
# convert object stacked mask to semantic
semantic_target_mask = (
torch.zeros(target_mask.shape).to(self.device).to(target_labels.dtype)
)
for cls in target_labels.unique():
obj_filt = target_obj_idx[target_labels == cls]
# cls +1 because detect classes start at 0 and 0 is background in semantic
semantic_target_mask[torch.isin(target_mask, obj_filt)] = cls + 1
# handling empty prediction
if prediction.size == 0:
semantic_prediction_mask = torch.zeros(prediction.spatial_size).to(
self.device
)
else:
prediction_mask = prediction.get("masks")._mask.clone()
prediction_labels = prediction.get("labels").clone()
prediction_obj_idx = torch.tensor(
range(1, prediction_labels.nelement() + 1)
).to(self.device)
# convert object stacked mask to semantic : all objects with given class replaced by their class
semantic_prediction_mask = (
torch.zeros(prediction_mask.shape)
.to(self.device)
.to(prediction_labels.dtype)
)
for cls in prediction_labels.unique():
obj_filt = prediction_obj_idx[prediction_labels == cls]
# cls +1 because detect classes start at 0 and 0 is background in semantic
semantic_prediction_mask[torch.isin(prediction_mask, obj_filt)] = (
cls + 1
)
# add dummy dim for stat score
flatpred = torch.flatten(semantic_prediction_mask)[None, :]
flattarget = torch.flatten(semantic_target_mask)[None, :]
if self.task == "binary":
self.stats.append(self.stat_score(flatpred, flattarget))
else:
# for multiclass will remove the background scores to avoid crazy high scores
self.stats.append(self.stat_score(flatpred, flattarget)[:, 1:, :])
[docs]
class MetricCompose:
"""Wrap multiple DetectTM in a same object:
- Merge metrics in same dictionnary output
- Allow all metrics to be displayed on same graph.
"""
def __init__(self, metrics: List[Metric], name="MetricCompose"):
self.metrics = metrics
self.name = name
[docs]
def update(self, prediction: Format, target: Format):
"""Update all metrics."""
for metric in self.metrics:
metric.update(prediction, target)
[docs]
def compute(self):
"""Compute on each metrics and merge metric dicts together."""
metric_dict = {}
for metric in self.metrics:
subdict = metric.compute()
subdict = {f"{metric.name}_{key}": value for key, value in subdict.items()}
metric_dict.update(subdict)
return metric_dict
[docs]
def reset(self):
"""Reset all metrics."""
for metric in self.metrics:
metric.reset()
def __call__(self, prediction: Format, target: Format):
"""Call each metric and merge outptu dicts."""
metric_dict = {}
for metric in self.metrics:
subdict = metric(prediction, target)
subdict = {f"{metric.name}_{key}": value for key, value in subdict}
metric_dict.update(subdict)
return metric_dict