from typing import Dict
import detectools.metrics.functionnals as F
from detectools.formats import BaseFormat
from detectools.metrics.base import (ClassifMetric, DetectMetric,
SemanticSegmentationMetric)
from torch import Tensor
from torchmetrics.detection import MeanAveragePrecision
[docs]
class DetectF1score(DetectMetric):
    """Detection metric reporting the F1 score.

    Thin wrapper that binds ``F.f1score`` and the name ``"DetectF1score"`` to
    ``DetectMetric``; all other arguments are forwarded unchanged.

    Args:
        iou_threshold (``float``): IoU threshold used to decide that a
            prediction box and a target box match. Default to 0.5.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.f1score,
            name="DetectF1score",
            **kwargs
        )
[docs]
class DetectPrecision(DetectMetric):
    """Detection metric reporting precision.

    Thin wrapper that binds ``F.precision`` and the name ``"DetectPrecision"``
    to ``DetectMetric``; all other arguments are forwarded unchanged.

    Args:
        iou_threshold (``float``): IoU threshold used to decide that a
            prediction box and a target box match. Default to 0.5.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.precision,
            name="DetectPrecision",
            **kwargs
        )
[docs]
class DetectRecall(DetectMetric):
    """Detection metric reporting recall.

    Thin wrapper that binds ``F.recall`` and the name ``"DetectRecall"`` to
    ``DetectMetric``; all other arguments are forwarded unchanged.

    Args:
        iou_threshold (``float``): IoU threshold used to decide that a
            prediction box and a target box match. Default to 0.5.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.recall,
            name="DetectRecall",
            **kwargs
        )
[docs]
class MeanAP(MeanAveragePrecision):
    """Compute Mean Average Precision (from torchmetrics MAP_ ).

    Adds a ``name`` attribute (``"MeanAP"``) and runs both inputs through
    ``prepare_input`` before delegating to the parent ``update``. Constructor
    arguments are forwarded unchanged to ``MeanAveragePrecision``.

    .. _MAP:
        https://lightning.ai/docs/torchmetrics/stable/detection/mean_average_precision.html
    """

    def __init__(self, *args, **kwargs):
        # ``super().__init__`` is already bound to this instance. Passing
        # ``self`` again would push it into the parent's first positional
        # parameter and shift every caller-supplied positional argument.
        super().__init__(*args, **kwargs)
        self.name = "MeanAP"

    def update(self, prediction: BaseFormat, target: BaseFormat) -> None:
        """Prepare inputs and call MAP.

        Args:
            prediction (``BaseFormat``): Predictions.
            target (``BaseFormat``): Targets.
        """
        # NOTE(review): ``prepare_input`` is not defined in the visible class
        # body; presumably provided elsewhere -- confirm before relying on it.
        prediction = self.prepare_input(prediction)
        target = self.prepare_input(target)
        super().update(prediction, target)
## classification metrics
[docs]
class ClassifF1score(ClassifMetric):
    """Classification metric reporting the F1 score.

    Thin wrapper that binds ``F.f1score`` and the name ``"ClassifF1score"`` to
    ``ClassifMetric``; all other arguments are forwarded unchanged.

    Args:
        num_classes (``int``): Number of classes for the task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.f1score,
            name="ClassifF1score",
            **kwargs
        )
[docs]
class ClassifPrecision(ClassifMetric):
    """Classification metric reporting precision.

    Thin wrapper that binds ``F.precision`` and the name
    ``"ClassifPrecision"`` to ``ClassifMetric``; all other arguments are
    forwarded unchanged.

    Args:
        num_classes (``int``): Number of classes for the task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.precision,
            name="ClassifPrecision",
            **kwargs
        )
[docs]
class ClassifRecall(ClassifMetric):
    """Classification metric reporting recall.

    Thin wrapper that binds ``F.recall`` and the name ``"ClassifRecall"`` to
    ``ClassifMetric``; all other arguments are forwarded unchanged.

    Args:
        num_classes (``int``): Number of classes for the task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.recall,
            name="ClassifRecall",
            **kwargs
        )
## semantic segmentation metrics
[docs]
class SemanticF1score(SemanticSegmentationMetric):
    """Semantic segmentation metric reporting the F1 score.

    Thin wrapper that binds ``F.f1score`` and the name ``"SemanticF1score"``
    to ``SemanticSegmentationMetric``; all other arguments are forwarded
    unchanged.

    Args:
        num_classes (``int``): Number of classes for the task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.f1score,
            name="SemanticF1score",
            **kwargs
        )
[docs]
class SemanticIoU(SemanticSegmentationMetric):
    """Semantic segmentation metric reporting IoU.

    Thin wrapper that binds ``F.iou`` and the name ``"SemanticIoU"`` to
    ``SemanticSegmentationMetric``; all other arguments are forwarded
    unchanged.

    Args:
        num_classes (``int``): Number of classes for the task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.iou,
            name="SemanticIoU",
            **kwargs
        )
[docs]
class SemanticAccuracy(SemanticSegmentationMetric):
    """Semantic segmentation metric reporting accuracy.

    Thin wrapper that binds ``F.accuracy`` and the name ``"SemanticAccuracy"``
    to ``SemanticSegmentationMetric``; all other arguments are forwarded
    unchanged.

    Args:
        num_classes (``int``): Number of classes for the task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            func=F.accuracy,
            name="SemanticAccuracy",
            **kwargs
        )