from typing import Tuple
import torch
from detectools.formats import Format
from torch import Tensor
from torchvision.ops import box_iou
[docs]
def match_boxes(
    prediction: Format, target: Format, iou_threshold: float = 0.5
) -> Tuple[Tensor, Tensor, Tensor, Tuple[Tensor, Tensor]]:
    """Match prediction boxes with target boxes and count detection statistics.

    Prediction ``i`` and target ``j`` are matched when they are each other's
    best candidate and their IoU is strictly greater than ``iou_threshold``.
    "Best candidate" means ``j`` is the target with the highest IoU for ``i``,
    and ``i`` is the prediction with the highest IoU for ``j``. The matching is
    therefore one-to-one: every prediction and every target appears in at most
    one pair.

    Boxes are temporarily converted to the ``"XYXY"`` format for the IoU
    computation. The original format of both inputs is restored before
    returning, even if an error is raised.

    Args:
        prediction (Format): Prediction.
        target (Format): Target.
        iou_threshold (float, optional): IoU that a pair must strictly exceed
            to count as a match. Defaults to 0.5.

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tuple[Tensor, Tensor]]: A tuple
        ``(tp, fp, fn, (pred_idxs, target_idxs))`` where:

        - ``tp`` is the number of matched pairs.
        - ``fp = n_predictions - tp`` (unmatched predictions).
        - ``fn = n_targets - tp`` (unmatched targets).
        - ``pred_idxs[k]`` / ``target_idxs[k]`` are the indexes of the k-th
          matched pair, as 1-D int64 CPU tensors sorted by prediction index.

        ``tp``, ``fp`` and ``fn`` are 0-dim integer tensors.

    Raises:
        AssertionError: If ``prediction`` and ``target`` use different box formats.
    """
    assert (
        prediction.box_format == target.box_format
    ), f"Prediction and target should have the same box_format, got {prediction.box_format} & {target.box_format}"
    origin_box_format = prediction.box_format
    # torchvision's box_iou works on XYXY boxes: convert temporarily.
    prediction.set_boxes_format("XYXY")
    target.set_boxes_format("XYXY")
    try:
        pred_boxes = prediction.get("boxes")
        target_boxes = target.get("boxes")
        n_preds, n_targets = pred_boxes.shape[0], target_boxes.shape[0]

        if n_preds == 0 or n_targets == 0:
            # The reductions below cannot run over an empty dimension, and
            # with no predictions or no targets nothing can be matched anyway.
            tp = torch.zeros((), dtype=torch.long, device=pred_boxes.device)
            pred_idxs = torch.zeros(0, dtype=torch.long)
            target_idxs = torch.zeros(0, dtype=torch.long)
        else:
            # cross_ious[i, j] = IoU(prediction i, target j)
            cross_ious = box_iou(pred_boxes, target_boxes)
            best_ious, best_target = cross_ious.max(dim=1)
            best_pred = cross_ious.argmax(dim=0)
            # Prediction i keeps its best target only if that target's best
            # prediction is i itself (mutual best). Each i has one best target
            # and each j one best prediction, so the matching is one-to-one.
            pred_range = torch.arange(n_preds, device=cross_ious.device)
            is_match = (best_pred[best_target] == pred_range) & (
                best_ious > iou_threshold
            )
            tp = is_match.sum()
            (matched,) = torch.where(is_match)
            pred_idxs = matched.cpu()
            target_idxs = best_target[matched].cpu()

        # Unmatched predictions are false positives and unmatched targets are
        # false negatives, so tp + fp == n_preds and tp + fn == n_targets.
        fp = n_preds - tp
        fn = n_targets - tp
    finally:
        # Restore the caller's box format even if matching raised.
        prediction.set_boxes_format(origin_box_format)
        target.set_boxes_format(origin_box_format)
    return tp, fp, fn, (pred_idxs, target_idxs)
# functional metrics
[docs]
def f1score(tp, fp, tn, fn):
return (2 * tp) / (2 * tp + fp + fn)
[docs]
def precision(tp, fp, tn, fn):
    """Compute precision ``TP / (TP + FP)`` from detection statistics.

    ``tn`` and ``fn`` are accepted so every metric shares the same signature,
    but they are unused. Returns ``torch.tensor(nan)`` when there are no
    predicted positives (``tp + fp == 0``).
    """
    predicted_positives = tp + fp
    return (
        tp / predicted_positives
        if predicted_positives != 0
        else torch.tensor(torch.nan)
    )
[docs]
def recall(tp, fp, tn, fn):
    """Compute recall ``TP / (TP + FN)`` from detection statistics.

    ``fp`` and ``tn`` are accepted so every metric shares the same signature,
    but they are unused. Returns ``torch.tensor(nan)`` when ``tp + fn == 0``,
    consistent with ``precision``.
    """
    denominator = tp + fn
    # Tensor division already yields NaN for 0 / 0 (and must stay elementwise
    # for batched tensors); plain Python numbers would raise ZeroDivisionError.
    if not isinstance(denominator, Tensor) and denominator == 0:
        return torch.tensor(torch.nan)
    return tp / denominator
[docs]
def iou(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tensor:
    """Compute IoU ``TP / (TP + FP + FN)`` from detection statistics.

    ``tn`` is accepted so every metric shares the same signature, but it is unused.
    Returns ``torch.tensor(nan)`` when the denominator is zero, consistent with
    ``precision``.
    """
    denominator = tp + fp + fn
    # Tensor division already yields NaN for 0 / 0 (and must stay elementwise
    # for batched tensors); plain Python numbers would raise ZeroDivisionError.
    if not isinstance(denominator, Tensor) and denominator == 0:
        return torch.tensor(torch.nan)
    return tp / denominator
[docs]
def accuracy(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor):
    """Compute accuracy ``(TP + TN) / (TP + TN + FP + FN)`` from detection statistics.

    Returns ``torch.tensor(nan)`` when the denominator is zero, consistent with
    ``precision``.
    """
    denominator = tn + tp + fp + fn
    # Tensor division already yields NaN for 0 / 0 (and must stay elementwise
    # for batched tensors); plain Python numbers would raise ZeroDivisionError.
    if not isinstance(denominator, Tensor) and denominator == 0:
        return torch.tensor(torch.nan)
    return (tp + tn) / denominator