Source code for detectools.models.yolov8_seg

from math import ceil, floor
from typing import Any, Dict, List, Literal, Tuple, Union

import torch
import torchvision.transforms.v2.functional as F
from detectools import Task
from detectools.formats import BatchedFormats, SegmentationFormat
from detectools.models.base import BaseModel
from torch import Tensor
from torchvision.ops import nms
from torchvision.transforms.v2 import ConvertBoundingBoxFormat
from torchvision.tv_tensors import BoundingBoxes
from ultralytics.cfg import get_cfg
from ultralytics.nn.tasks import SegmentationModel, attempt_load_one_weight
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.ops import scale_masks
from detectools.formats.detect_mask import DetectMask


class Yolov8Segmentation(SegmentationModel, BaseModel):
    """YOLOv8 instance segmentation model for detectools.

    This class inherits from SegmentationModel_ (Ultralytics) and BaseModel (detectools).
    The YOLO architecture is built from the Ultralytics ``<architecture>.yaml`` config and,
    if ``pretrained``, the matching ``<architecture>.pt`` weights are loaded.

    .. _SegmentationModel: https://docs.ultralytics.com/reference/nn/tasks/?h=detectionmodel#ultralytics.nn.tasks.SegmentationModel.__init__

    Args:
        architecture (``str``, **optional**): Architecture to use to build YOLO model. Check Ultralytics available architectures_ . Defaults to "yolov8n-seg".
        pretrained (``bool``, **optional**): To use pretrained weights. Defaults to True.
        confidence_thr (``float``, **optional**): Confidence score threshold to consider object as true prediction. Defaults to 0.5.
        max_detection (``int``, **optional**): Maximum number of objects to predict on one image. Defaults to 300.
        nms_threshold (``float``, **optional**): IoU threshold to consider 2 boxes as overlapping for Non Max Suppression algorithm. Defaults to 0.45.
        num_classes (``int``, **optional**): Number of classes in the task. Defaults to 1.

    .. _architectures: https://docs.ultralytics.com/models/yolov8/#supported-tasks-and-modes

    Attributes:
        confidence_thr (``float``): Confidence score threshold to consider object as true prediction.
        max_detection (``int``): Maximum number of objects to predict on one image.
        nms_threshold (``float``): IoU threshold to consider 2 boxes as overlapping for Non Max Suppression algorithm.
        num_classes (``int``): Number of classes.
    """

    def __init__(
        self,
        architecture: str = "yolov8n-seg",
        pretrained: bool = True,
        confidence_thr: float = 0.5,
        max_detection: int = 300,
        nms_threshold: float = 0.45,
        num_classes: int = 1,
        *args,
        **kwargs,
    ):
        # this model only makes sense when the global Task is instance segmentation
        assert (
            Task.mode == "instance_segmentation"
        ), f"Task mode should be 'instance_segmentation' to construct Yolov8Segmentation object, got {Task.mode}"
        # build model from ultralytics config
        super().__init__(f"{architecture}.yaml", *args, nc=num_classes, **kwargs)
        # self.args must be set before init_criterion() is called
        self.args = get_cfg(DEFAULT_CFG)
        self.criterion = self.init_criterion()
        self.num_classes = num_classes
        self.confidence_thr = confidence_thr
        self.max_detection = max_detection
        self.nms_threshold = nms_threshold
        # load weights from ultralytics repo if pretrained
        if pretrained:
            # attempt_load_one_weight returns a sequence; its first item is passed to load()
            loaded = attempt_load_one_weight(f"{architecture}.pt")
            self.load(loaded[0])

    def to_device(self, device: Literal["cpu", "cuda"]):
        """Send model & criterion to device.

        Args:
            device (``Literal['cpu', 'cuda']``): Device to send model on.
        """
        self.to(device)
        # rebuild the criterion after the move so it is created from the moved model
        self.criterion = self.init_criterion()

    def prepare(
        self, images: Tensor, targets: Union[BatchedFormats, None] = None
    ) -> Union[Any, Tuple[Any]]:
        """Transform images and targets into YOLO specific format for prediction & loss computation.

        Args:
            images (``Tensor``): Batch images.
            targets (``BatchedFormats``, **optional**): Batched targets from DetectionDataset.

        Returns:
            ``Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]``:
                - Images data prepared for YOLO.
                - If targets: images + targets prepared for YOLO.
        """
        # pad images to a spatial size divisible by 32
        images, (left, top, right, bottom) = self.prepare_image(images)
        # explicit None check: the truthiness of a BatchedFormats is not a reliable
        # "was a target passed" signal and callers unpack a tuple when they pass one
        if targets is None:
            return images
        # work on a copy, prepare_target modifies its argument in place
        prepared_targets = targets.clone()
        prepared_targets.apply("pad", left, top, right, bottom)
        prepared_targets = self.prepare_target(prepared_targets)
        return images, prepared_targets

    def build_results(
        self,
        raw_output: Tuple[Tensor, ...],
    ) -> BatchedFormats:
        """Transform model outputs into Batch SegmentationFormat for results.

        Args:
            raw_output (``Tuple[Tensor, ...]``): Model outputs (eval mode).

        Returns:
            ``BatchedFormats``:
                - Batched predictions.
        """
        # extract information from raw results
        boxes, cls_scores, mask_weights, protos = self.prebuild_output(raw_output)
        device = boxes.device
        spatial_size = self.retrieve_spatial_size(raw_output)
        # build one SegmentationFormat per image of the batch
        results = [
            self._build_image_result(
                image_boxes,
                image_cls_scores,
                image_mask_weights,
                image_protos,
                spatial_size,
            )
            for image_boxes, image_cls_scores, image_mask_weights, image_protos in zip(
                boxes, cls_scores, mask_weights, protos
            )
        ]
        if not results:
            results = [SegmentationFormat.empty(spatial_size, device=device)]
        return BatchedFormats(results)

    def _build_image_result(
        self,
        image_boxes: Tensor,
        image_cls_scores: Tensor,
        image_mask_weights: Tensor,
        image_protos: Tensor,
        spatial_size: Tuple[int, int],
    ) -> SegmentationFormat:
        """Build the prediction of a single image: confidence filter, NMS, top-N selection and masks."""
        device = image_boxes.device
        # get best class and corresponding score
        image_cls_scores, image_labels = torch.max(image_cls_scores, dim=1)
        # filter by confidence threshold (boolean mask keeps a 1-D sequence even for a single hit)
        confident = image_cls_scores > self.confidence_thr
        image_boxes = image_boxes[confident]
        image_cls_scores = image_cls_scores[confident]
        image_mask_weights = image_mask_weights[confident]
        image_labels = image_labels[confident]
        # if no objects with good confidence return empty SegmentationFormat
        if image_labels.nelement() == 0:
            return SegmentationFormat.empty(spatial_size, device=device)
        # apply NMS on boxes (needs XYXY) to retrieve non overlapped objects
        image_boxes = BoundingBoxes(
            image_boxes,
            canvas_size=spatial_size,
            format="CXCYWH",
        )
        image_boxes = ConvertBoundingBoxFormat("XYXY")(image_boxes)
        nms_indexes = nms(image_boxes, image_cls_scores, self.nms_threshold)
        image_boxes = image_boxes[nms_indexes]
        image_cls_scores = image_cls_scores[nms_indexes]
        image_mask_weights = image_mask_weights[nms_indexes]
        image_labels = image_labels[nms_indexes]
        # select N objects (N == self.max_detection) with highest scores, ascending score order
        top_indexes = torch.argsort(image_cls_scores)[-self.max_detection :]
        image_boxes = image_boxes[top_indexes]
        image_cls_scores = image_cls_scores[top_indexes]
        image_mask_weights = image_mask_weights[top_indexes]
        image_labels = image_labels[top_indexes]
        # compute masks per remaining object
        image_masks = self.proto2mask(
            image_protos, image_mask_weights, image_boxes, spatial_size
        )
        # binarize the sigmoid masks (value > 0.5 belongs to object)
        # TODO pass this threshold to model attribute
        image_masks = (image_masks > 0.5).int()
        # create DetectMask and remove objects with no mask
        segmentation_mask: DetectMask = DetectMask.from_binary_masks(image_masks)
        keep_indexes = segmentation_mask.reindex()
        if not keep_indexes.nelement():
            return SegmentationFormat.empty(spatial_size, device=device)
        image_boxes = image_boxes[keep_indexes]
        image_cls_scores = image_cls_scores[keep_indexes]
        image_labels = image_labels[keep_indexes]
        # create SegmentationFormat
        result = SegmentationFormat(
            spatial_size,
            image_labels,
            image_boxes,
            segmentation_mask,
            scores=image_cls_scores,
            box_format="XYXY",
        )
        # send boxes to xywh
        result.set_boxes_format("XYWH")
        return result

    def run_forward(
        self, images: Tensor, targets: BatchedFormats, predict: bool = False
    ) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], BatchedFormats]]:
        """Compute the loss from images and targets and, if predict, also return predictions.

        Args:
            images (``Tensor``): Batch RGB images.
            targets (``BatchedFormats``): Batch targets.
            predict (``bool``, **optional**): To return predictions or not. Defaults to False.

        Returns:
            ``Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], BatchedFormats]]``:
                - Loss dict.
                - If predict: predictions.
        """
        # predictions can only be built from eval mode outputs
        assert predict == (
            not self.training
        ), f"Model mode should be equal to predict boolean, got {self.training} & {predict}"
        # prepare inputs
        prepared_images, prepared_targets = self.prepare(images, targets=targets)
        # run forward pass
        raw_outputs = self(prepared_images)
        # compute loss
        loss_dict = self.compute_loss(raw_outputs, prepared_targets)
        if not predict:
            return loss_dict
        # build predictions and crop them back to the original (unpadded) spatial size
        predictions = self.build_results(raw_outputs)
        left, top, _, _ = self.yolo_pad_requirements(images)
        h, w = images.shape[-2:]
        predictions.apply("crop", top, left, h, w)
        return loss_dict, predictions

    def get_predictions(self, images: Tensor) -> BatchedFormats:
        """Prepare images, apply YOLO forward pass and build results.

        Args:
            images (``Tensor``): RGB images Tensor.

        Returns:
            ``BatchedFormats``:
                - Predictions for images as BatchedFormats.
        """
        self.eval()
        # get original spatial size
        ori_h, ori_w = images.shape[-2:]
        # pad images
        images, (left, top, _, _) = self.prepare_image(images)
        # predict
        raw_outputs = self(images)
        results = self.build_results(raw_outputs)
        # crop to get back to original spatial size
        results.apply("crop", top, left, ori_h, ori_w)
        return results

    def prebuild_output(self, raw_output: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
        """Unpack Yolov8-seg (eval mode) raw results.

        Args:
            raw_output (``Tuple[Tensor, ...]``): Yolov8 raw eval mode results.

        Returns:
            ``Tuple[Tensor, ...]``:
                - boxes (N_batch, N_obj, cxcywh).
                - cls_scores (N_batch, N_obj, N_cls).
                - mask_weights (N_batch, N_obj, 32).
                - protos (N_batch, protos).
        """
        output0, output1 = raw_output
        # permute in N_batch, N_obj, obj_length
        output0 = output0.permute(0, 2, 1)
        boxes = output0[:, :, 0:4]
        cls_indx = 4 + self.num_classes
        cls_scores = output0[:, :, 4:cls_indx]
        # the last 32 values of each object are read as the mask weights
        mask_weights = output0[:, :, -32:]
        protos = output1[2]
        return boxes, cls_scores, mask_weights, protos

    def prepare_image(self, images: Tensor) -> Tuple[Tensor, Tuple[int, ...]]:
        """Pad images if needed & return padding values.

        Args:
            images (``Tensor``): Batch images.

        Returns:
            ``Tuple[Tensor, Tuple[int]]``:
                - Padded images.
                - Padding values (left, top, right, bottom).
        """
        # get borders padding
        coordinates = self.yolo_pad_requirements(images)
        # pad images
        images = F.pad(images, list(coordinates))
        return images, coordinates

    def prepare_target(self, target: BatchedFormats) -> Dict[str, Tensor]:
        """Transform SegmentationFormat targets into yolo-seg targets format.

        The passed ``target`` is modified in place (boxes format & normalization).

        Args:
            target (``BatchedFormats``): Batch targets.

        Returns:
            ``Dict[str, Tensor]``:
                - Targets in YOLO format.
        """
        target.apply("set_boxes_format", "CXCYWH")
        target.apply("normalize")
        targets: List[SegmentationFormat] = target.split()
        boxes = torch.cat([t.get("boxes") for t in targets])
        labels = torch.cat([t.get("labels") for t in targets])
        device = labels.device
        masks = torch.stack([t.get("masks")._mask for t in targets])
        # for each object, index of the image it belongs to
        images_indices = torch.cat(
            [torch.full((t.size,), i, device=device) for i, t in enumerate(targets)]
        )
        # put labels and batch_idx in yolo format : Tensor (N, 1)
        yolotarget = {
            "masks": masks,
            "bboxes": boxes,
            "cls": labels[:, None],
            "batch_idx": images_indices[:, None],
        }
        return yolotarget

    def yolo_pad_requirements(
        self, input_object: Union[Tensor, SegmentationFormat]
    ) -> Tuple[int, ...]:
        """Return values for padding to fit 'divisible by 32' requirement.

        Args:
            input_object (``Union[Tensor, SegmentationFormat]``): Input to pad (image or SegmentationFormat).

        Returns:
            ``Tuple[int]``:
                - Padding values (left, top, right, bottom).

        Raises:
            TypeError: If ``input_object`` is neither a Tensor nor a SegmentationFormat.
        """
        # get spatial size
        if isinstance(input_object, SegmentationFormat):
            h, w = input_object.spatial_size
        elif isinstance(input_object, Tensor):
            h, w = input_object.shape[-2:]  # (H, W)
        else:
            raise TypeError(
                f"input_object should be a Tensor or a SegmentationFormat, got {type(input_object).__name__}"
            )
        # missing pixels to reach the next multiple of 32 (0 if already a multiple)
        pad_h = (32 - h % 32) % 32
        pad_w = (32 - w % 32) % 32
        # split padding on both sides, the extra pixel of an odd padding goes to left / top
        left, right = ceil(pad_w / 2), floor(pad_w / 2)
        top, bottom = ceil(pad_h / 2), floor(pad_h / 2)
        return (left, top, right, bottom)

    def retrieve_spatial_size(self, raw_outputs: List[Tensor]) -> Tuple[int, int]:
        """Retrieve image shape from raw_outputs and stride values.

        Args:
            raw_outputs (``List[Tensor]``): Raw outputs from YOLO model.

        Returns:
            ``Tuple[int]``:
                - Size of input image (H, W).
        """
        # the feature maps are nested one level deeper in eval mode outputs
        feature_maps = raw_outputs[0] if self.training else raw_outputs[1][0]
        map_h, map_w = feature_maps[0].shape[-2:]
        # first feature map size times the first stride gives the input size
        stride = self.stride[0]
        return (int(map_h * stride), int(map_w * stride))

    def compute_loss(self, predictions: Tuple, target: Dict) -> Dict[str, Tensor]:
        """Compute loss with predictions & targets.

        Args:
            predictions (``Tuple``): Raw output of model.
            target (``Dict``): Targets in YOLO format.

        Returns:
            ``Dict[str, Tensor]``:
                - Loss dict with total loss (key: "loss") & sublosses.
        """
        loss, loss_detail = self.criterion(predictions, target)
        loss_box, loss_seg, loss_cls, loss_dfl = (loss_detail[i] for i in range(4))
        return {
            "loss": loss,
            "loss_box": loss_box,
            "loss_seg": loss_seg,
            "loss_cls": loss_cls,
            "loss_dfl": loss_dfl,
        }

    def mask2yolo(self, mask: Tensor) -> Tensor:
        """Convert stacked binary masks to yolo mask, i.e (1, h, w) with values in [0, ... , Nobjs].

        This shape is suitable for yolov8 loss.

        Args:
            mask (``Tensor``): Stacked binary mask (N, H, W).

        Returns:
            ``Tensor``:
                - YOLO segmentation mask.
        """
        if mask.ndim < 3:
            mask = mask[None, :]
        n_objects = mask.shape[0]
        # object k (0-based) is written with value k + 1, 0 being absence of object
        reindexing = torch.arange(1, n_objects + 1, device=mask.device)
        if n_objects == 0:
            # no object: torch.max cannot reduce an empty dimension, return background only
            return torch.zeros(
                (1, *mask.shape[-2:]),
                dtype=torch.result_type(mask, reindexing),
                device=mask.device,
            )
        # where objects overlap, the highest object index wins
        yolomask, _ = torch.max(mask * reindexing[:, None, None], dim=0)
        return yolomask[None, :]

    def proto2mask(
        self, protos: Tensor, weights: Tensor, boxes: Tensor, shape: Tuple[int, int]
    ) -> Tensor:
        """Combine protos and weights to get masks, then crop instances from boxes (useful in predictions).

        Args:
            protos (``Tensor``): Sub masks (32, ...).
            weights (``Tensor``): YOLO mask weights (N, 32).
            boxes (``Tensor``): Boxes (N, 4) in XYXY format.
            shape (``Tuple[int]``): Original image size (H, W).

        Returns:
            ``Tensor``:
                - Masks (N, H, W) with sigmoid values, set to 0 outside of each object box.
        """
        c, mh, mw = protos.shape  # CHW
        masks = (weights @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
        masks: Tensor = scale_masks(masks[None], shape)[0]  # CHW
        height, width = masks.shape[-2:]
        # integer box borders, each of shape (N, 1, 1) to broadcast against pixel coordinates
        int_boxes = boxes.int()
        xl, yl, xr, yr = (int_boxes[:, k, None, None] for k in range(4))
        rows = torch.arange(height, device=masks.device)[None, :, None]
        cols = torch.arange(width, device=masks.device)[None, None, :]
        # comparisons (instead of slices) stay correct when a box border lies outside the
        # image: a negative border used as a slice bound would count from the end
        inside = (rows >= yl) & (rows < yr) & (cols >= xl) & (cols < xr)
        return masks * inside