"""Annotation and format classes for the instance segmentation task (``detectools.formats.segmentation_format``)."""
from __future__ import annotations
from typing import Any, Dict, List, Literal, Tuple, Union
import torch
from detectools import Task
from detectools.formats.detection_format import DetectionAnnotation, DetectionFormat
from detectools.formats.mask_utils import cocopolygons2mask, cocoseg2masks, mask2polygons
from torch import Tensor, is_floating_point
from torch.nn.functional import one_hot
from torchvision.ops import masks_to_boxes
from torchvision.transforms.v2 import ConvertBoundingBoxFormat
from torchvision.transforms.v2.functional import crop_mask, pad_mask
from torchvision.tv_tensors import BoundingBoxes, Mask
from detectools.formats.detect_mask import DetectMask
class SegmentationAnnotation(DetectionAnnotation):
    """DetectionAnnotation child class for a single instance segmentation object.

    Attributes:
        boxe (``BoundingBoxes``): Box coordinates, in the owning format's box format (XYWH by default).
        label (``Tensor``): Class label of the object.
        spatial_size (``Tuple[int, int]``): Size (H, W) of the corresponding image.
        score (``Tensor``): Confidence score of the object (for predictions), or None.
        mask (``Tensor``): Segmentation mask of the object, shape (H, W).
    """

    boxe: BoundingBoxes
    label: Tensor
    spatial_size: Tuple[int, int]
    score: Tensor
    mask: Tensor

    def __init__(
        self,
        spatial_size: Tuple[int, int],
        label: Tensor,
        boxe: Tensor,
        mask: Tensor,
        score: Union[Tensor, None] = None,
    ):
        # Masks are only meaningful when the global task mode is instance segmentation.
        assert (
            Task.mode == "instance_segmentation"
        ), f"Task mode should be 'instance_segmentation' to create SegmentationAnnotation, got {Task.mode}"
        # Spatial size, label, box and score are handled by DetectionAnnotation;
        # only the mask is specific to this class (it is stored as given, unchecked).
        super().__init__(spatial_size, label, boxe, score)
        self.mask = mask

    def object_to_coco(
        self, annotation_id: int = 1, image_id: int = 1
    ) -> Dict[str, Any]:
        """Return instance segmentation annotation data as a COCO-like dict.

        Args:
            annotation_id (``int``, **optional**): Id of the annotation. Defaults to 1.
            image_id (``int``, **optional**): Id of the corresponding image. Defaults to 1.

        Returns:
            ``Dict[str, Any]``:
                - COCO-like dict with the annotation instance data ("score" is only
                  present when the annotation carries a score).
        """
        # Convert the mask tensor to COCO polygons plus the area computed by mask2polygons.
        polygons, area = mask2polygons(self.mask)
        annotation = {
            "id": annotation_id,
            "bbox": self.boxe.tolist(),
            "segmentation": polygons,
            "area": area,
            "category_id": self.label.item(),
            "image_id": image_id,
        }
        # Compare against None explicitly: a score of 0.0 is falsy but still a valid
        # score, and truth-testing a multi-element tensor raises.
        if self.score is not None:
            annotation["score"] = self.score.item()
        return annotation
class SegmentationFormat(DetectionFormat):
    """DetectionFormat child class for the instance segmentation task.

    Args:
        spatial_size (``Tuple[int, int]``): Spatial size (H, W) of the corresponding image.
        labels (``Tensor``): Tensor of shape (N,) with the class label of each object.
        boxes (``Tensor``): Tensor of shape (N, 4). N for N objects and 4 for box coordinates.
        masks (``Tensor``): Tensor of shape (H, W) with values from 0 to N, one value per object
            (a ``DetectMask`` is also accepted).
        scores (``Tensor``, **optional**): Tensor of shape (N,) with object confidence scores. Defaults to None.
        box_format (``Literal['XYWH', 'XYXY', 'CXCYWH']``, **optional**): Format of bounding boxes. Defaults to 'XYWH'.

    Attributes:
        box_format (``Literal["XYWH", "XYXY", "CXCYWH"]``): Format of bounding boxes.
        spatial_size (``Tuple[int, int]``): Size (H, W) of the corresponding image.
        size (``int``): Number of objects stored in the format.
        data (``Dict[str, Tensor]``): Data dict holding the objects' information under its keys
            (labels, boxes, scores, masks).
    """

    # H, W image size corresponding to the objects' boxes/masks.
    spatial_size: Tuple[int, int]
    # All values (labels, boxes and masks at least) of the objects in an image.
    data: Dict[str, Tensor]
    # Number of objects in the image.
    size: int
    # Format of the bounding boxes.
    box_format: Literal["XYWH", "XYXY", "CXCYWH"]

    # override
    @staticmethod
    def empty(
        spatial_size: Tuple[int, int], device: Literal["cpu", "cuda"] = "cpu"
    ) -> SegmentationFormat:
        """Return an empty SegmentationFormat instance.

        Args:
            spatial_size (``Tuple[int, int]``): Size (H, W) of the corresponding image.
            device (``Literal["cpu", "cuda"]``): Device to define the format on. Defaults to "cpu".

        Returns:
            ``SegmentationFormat``:
                - SegmentationFormat instance with no object.
        """
        # (0, 4) boxes and int64 labels follow the "N objects" convention with N == 0,
        # so an empty format's tensors can be concatenated with non-empty ones without
        # shape errors or promotion of the labels to float.
        boxes = torch.zeros((0, 4), device=device)
        labels = torch.zeros((0,), dtype=torch.int64, device=device)
        masks = Mask(torch.zeros(spatial_size).int()).to(device)  # mask full of 0
        return SegmentationFormat(
            spatial_size=spatial_size, labels=labels, boxes=boxes, masks=masks
        )

    # override
    @staticmethod
    def from_coco(
        coco_annotations: List[Dict[str, Any]], spatial_size: Tuple[int, int]
    ) -> SegmentationFormat:
        """Return a SegmentationFormat built from an image's COCO annotations.

        Args:
            coco_annotations (``List[Dict[str, Any]]``): COCO annotation dicts of one image.
            spatial_size (``Tuple[int, int]``): Size (H, W) of the corresponding image.

        Returns:
            ``SegmentationFormat``:
                - SegmentationFormat instance.
        """
        # torch.tensor([]) would yield 1-D float boxes/labels; use the canonical empty format.
        if not coco_annotations:
            return SegmentationFormat.empty(spatial_size)
        boxes = torch.tensor([ann["bbox"] for ann in coco_annotations])
        labels = torch.tensor([ann["category_id"] for ann in coco_annotations])
        masks = DetectMask(torch.zeros(spatial_size))
        # Add each object's binary mask to the DetectMask under its annotation index.
        for i, ann in enumerate(coco_annotations):
            ann_mask = cocoseg2masks(ann["segmentation"], spatial_size)
            masks.add_binary_mask(ann_mask.int(), i)
        # Remove objects with no mask (including masks overwritten by other objects)
        # and keep boxes/labels aligned with the objects that remain.
        keep_indexes = masks.reindex()
        boxes = boxes[keep_indexes]
        labels = labels[keep_indexes]
        return SegmentationFormat(spatial_size, labels, boxes, masks)

    def __init__(
        self,
        spatial_size: Tuple[int, int],
        labels: Tensor,
        boxes: Tensor,
        masks: Tensor,
        scores: Union[Tensor, None] = None,
        box_format: Literal["XYWH", "XYXY", "CXCYWH"] = "XYWH",
    ):
        # Masks are only meaningful when the global task mode is instance segmentation.
        assert (
            Task.mode == "instance_segmentation"
        ), f"Task mode should be 'instance_segmentation' to create SegmentationFormat, got {Task.mode}"
        if not isinstance(masks, DetectMask):
            masks = DetectMask(masks)
        # reindex() drops objects absent from the mask and returns the indexes of the
        # kept ones; apply them to every per-object tensor so all stay aligned.
        # Scores are filtered here too: indexing them unconditionally would hit an
        # undefined keep_indexes when there are no labels.
        if labels.nelement():
            keep_indexes = masks.reindex()
            boxes = boxes[keep_indexes]
            labels = labels[keep_indexes]
            if isinstance(scores, Tensor):
                scores = scores[keep_indexes]
        boxes = BoundingBoxes(boxes.int(), canvas_size=spatial_size, format=box_format)
        self.data: Dict[str, Tensor] = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
        }
        if isinstance(scores, Tensor):
            self.data["scores"] = scores
        self.box_format = box_format
        self.size = labels.nelement()
        self.spatial_size = spatial_size

    def set(self, key: str, value: Tensor):
        """Set a new key/value pair. Value should be of shape (N, ...) with N == self.size.

        If key is "masks" and value is binary masks (N, H, W), the size is N; if value is
        a stacked mask (H, W), the size is the number of objects it contains.

        Args:
            key (``str``): Key of the value to set.
            value (``Tensor``): Data as a tensor.
        """
        # Compute the number of objects carried by the new value.
        if key == "masks":
            if not isinstance(value, DetectMask):
                value = DetectMask(value)
            data_size = value.n_objects
        else:
            data_size = value.size()[0] if value.nelement() else 0
        assert (
            data_size == self.size
        ), f"New value size should be equal to self.size, got {data_size} and {self.size}."
        # Store the value on the same device as the rest of the data.
        device = self.get_device()
        self.data[key] = value.to(device)

    def get_object(self, indice: int) -> SegmentationAnnotation:
        """Return the SegmentationAnnotation of the object at position ``indice``.

        Args:
            indice (``int``): Position of the object to gather.

        Returns:
            ``SegmentationAnnotation``:
                - SegmentationAnnotation instance (with its score when the format holds scores).
        """
        single_object_format = self[indice]
        bbox, label, mask = single_object_format.get("boxes", "labels", "masks")
        segmentation_object = SegmentationAnnotation(
            self.spatial_size, label, bbox.squeeze(), mask._mask
        )
        if "scores" in single_object_format:
            segmentation_object.score = single_object_format.get("scores")
        return segmentation_object

    # Methods that change internal states of Formats

    def crop(self, top: int, left: int, height: int, width: int):
        """Crop boxes and masks from the top-left corner pixel and update the spatial size.

        Args:
            top (``int``): Position to crop from the top border.
            left (``int``): Position to crop from the left border.
            height (``int``): Height of the crop.
            width (``int``): Width of the crop.

        Returns:
            ``SegmentationFormat``:
                - self, cropped in place.
        """
        self.spatial_size = (height, width)
        if self.size == 0:
            # No object to filter, but the background mask must still follow the new
            # spatial size, otherwise it keeps the pre-crop (H, W).
            background = self.get("masks")._mask
            self.data["masks"] = DetectMask(
                crop_mask(background, top=top, left=left, height=height, width=width)
            )
            return self
        super().crop(top, left, height, width)
        masks: DetectMask = self.get("masks")
        cropped_masks = DetectMask(
            crop_mask(masks._mask, top=top, left=left, height=height, width=width)
        )
        # Objects whose mask falls entirely outside the crop are removed.
        keep_indexes = cropped_masks.reindex()
        # Only copy the dict to avoid duplicating the object and losing mask changes.
        self.__dict__ = self[keep_indexes].__dict__.copy()
        self.set("masks", cropped_masks)
        return self

    def pad(self, left: int, top: int, right: int, bottom: int):
        """Pad boxes and masks and update the spatial size.

        Args:
            left (``int``): Pad value on the left border.
            top (``int``): Pad value on the top border.
            right (``int``): Pad value on the right border.
            bottom (``int``): Pad value on the bottom border.
        """
        super().pad(left, top, right, bottom)
        masks: DetectMask = self.get("masks")
        padded_masks = pad_mask(masks._mask, padding=[left, top, right, bottom])
        if self.size == 0:
            # Keep the background mask in sync with the padded spatial size.
            self.data["masks"] = DetectMask(padded_masks)
            return
        self.set("masks", padded_masks)

    def rescale_boxes_from_masks(self):
        """Recompute boxes from mask contours, splitting objects made of several parts.

        For each object, its mask is converted to COCO polygons:
            - if there is only one contour, the box is recomputed from that contour;
            - if there is more than one contour, each contour becomes its own object
              (own mask and box) and the object's other data (labels, scores, ...) is
              duplicated.
        Objects whose mask yields no polygon are dropped.
        """
        if self.size == 0:
            # Nothing to rescale (and torch.stack cannot stack an empty list).
            return
        device = self.get_device()
        masks: DetectMask = self.get("masks")
        new_boxes: List[Tensor] = []
        new_masks = DetectMask(torch.zeros(self.spatial_size))
        # Index of the source object for each new object, used to duplicate other data.
        source_indexes: List[int] = []
        for i, mask in enumerate(masks):
            polygons, _ = mask2polygons(mask._mask)
            for polygon in polygons:
                sub_mask = cocopolygons2mask([polygon], self.spatial_size)
                new_boxes.append(masks_to_boxes(sub_mask.unsqueeze(0)).squeeze(0))
                new_masks.add_binary_mask(sub_mask)
                source_indexes.append(i)
        stacked_boxes = torch.stack(new_boxes) if new_boxes else torch.zeros((0, 4))
        # masks_to_boxes yields XYXY boxes; convert them to this format's box format.
        boxes = BoundingBoxes(stacked_boxes, canvas_size=self.spatial_size, format="XYXY")
        boxes = ConvertBoundingBoxFormat(self.box_format)(boxes)
        # Rebuilt boxes/masks are created on CPU; move them next to the other data.
        self.data["boxes"] = boxes.to(device)
        self.data["masks"] = new_masks.to(device)
        other_keys = [key for key in self.data if key not in ("boxes", "masks")]
        for key in other_keys:
            self.data[key] = self.data[key][source_indexes]
        self.size = boxes.shape[0]