# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
from enum import Enum, auto
from itertools import zip_longest
from typing import (
Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union,
)
from attr import asdict, attrs, field
from typing_extensions import Literal
import attr
import numpy as np
from datumaro.util.attrs_util import default_if_none, not_empty
class AnnotationType(Enum):
    """
    Enumerates the kinds of annotations declared in this module.

    Member values are generated with auto(), so they follow the
    declaration order below.
    """

    label = auto()
    mask = auto()
    points = auto()
    polygon = auto()
    polyline = auto()
    bbox = auto()
    caption = auto()
    cuboid_3d = auto()
# Number of decimal digits kept when shape coordinates are rounded
COORDINATE_ROUNDING_DIGITS = 2

# The Annotation.group value that means "not a member of any group"
NO_GROUP = 0
@attrs(slots=True, kw_only=True, order=False)
class Annotation:
    """
    A base annotation class.

    Derived classes must define the '_type' class variable with a value
    from the AnnotationType enum.
    """

    # Describes an identifier of the annotation
    # Is not required to be unique within DatasetItem annotations or dataset
    # NOTE(review): default_if_none comes from datumaro.util.attrs_util;
    # presumably it substitutes a default value when None is passed - confirm
    id: int = field(default=0, validator=default_if_none(int))

    # Arbitrary annotation-specific attributes. Typically, includes
    # metainfo and properties that are not covered by other fields.
    # If possible, try to limit value types of values by the simple
    # builtin types (int, float, bool, str) to increase compatibility with
    # different formats.
    # There are some established names for common attributes like:
    # - "occluded" (bool)
    # - "visible" (bool)
    # Possible dataset attributes can be described in Categories.attributes.
    attributes: Dict[str, Any] = field(
        factory=dict, validator=default_if_none(dict))

    # Annotations can be grouped, which means they describe parts of a
    # single object. The value of 0 means there is no group.
    group: int = field(default=NO_GROUP, validator=default_if_none(int))

    @property
    def type(self) -> AnnotationType:
        """The annotation kind, read from the '_type' class variable."""
        return self._type # must be set in subclasses

    def as_dict(self) -> Dict[str, Any]:
        """Returns a dictionary { field_name: value }"""
        return asdict(self)

    def wrap(self, **kwargs):
        """Returns a modified copy of the object"""
        return attr.evolve(self, **kwargs)
@attrs(slots=True, kw_only=True, order=False)
class Categories:
    """
    A base class for annotation metainfo. It is supposed to include
    dataset-wide metainfo like available labels, label colors,
    label attributes etc.
    """

    # Describes the list of possible annotation-type specific attributes
    # in a dataset. Excluded from equality comparison (eq=False).
    attributes: Set[str] = field(
        factory=set, validator=default_if_none(set), eq=False)
@attrs(slots=True, order=False)
class LabelCategories(Categories):
    """
    An ordered collection of label definitions.

    A label is addressed either by its position in 'items' or by its name;
    the name -> position mapping is kept in the private '_indices' dict.
    """

    @attrs(slots=True, order=False)
    class Category:
        """
        A single label definition: a name, an optional parent label name
        and a set of attribute names.
        """

        name: str = field(converter=str, validator=not_empty)
        parent: str = field(default='', validator=default_if_none(str))
        attributes: Set[str] = field(
            factory=set, validator=default_if_none(set))

    # The label definitions; the list position is the label index
    items: List[Category] = field(
        factory=list, validator=default_if_none(list))

    # Maps a label name to its position in 'items'. Not an __init__
    # argument and not compared in __eq__; rebuilt by _reindex().
    _indices: Dict[str, int] = field(factory=dict, init=False, eq=False)

    @classmethod
    def from_iterable(cls, iterable: Iterable[Union[
        str,
        Tuple[str],
        Tuple[str, str],
        Tuple[str, str, List[str]],
    ]]) -> LabelCategories:
        """
        Creates a LabelCategories from iterable.

        Args:
            iterable: This iterable object can be:
                - a list of str - will be interpreted as list of Category names
                - a list of positional arguments - will generate Categories
                  with these arguments

        Returns: a LabelCategories object
        """
        temp_categories = cls()

        for category in iterable:
            # A bare string is a label name, not a sequence of arguments
            if isinstance(category, str):
                category = [category]
            temp_categories.add(*category)

        return temp_categories

    def __attrs_post_init__(self):
        self._reindex()

    def _reindex(self):
        """
        Rebuilds the name -> index mapping from 'items'.

        Label names must be unique; a duplicate fails the assertion.
        """
        indices = {}
        for index, item in enumerate(self.items):
            # Check against the mapping being built, not the previous one:
            # the previous mapping is empty on construction (so duplicates
            # would go unnoticed) and already holds every name on a later
            # re-index (so the check would always fail).
            assert item.name not in indices, item.name
            indices[item.name] = index
        self._indices = indices

    def add(self, name: str, parent: Optional[str] = None,
            attributes: Optional[Set[str]] = None) -> int:
        """
        Appends a new label.

        Args:
            name: The label name. Must be non-empty and not yet registered.
            parent: The name of the parent label, if any.
            attributes: Attribute names specific to this label.

        Returns: the index of the new label
        """
        assert name
        assert name not in self._indices, name

        index = len(self.items)
        self.items.append(self.Category(name, parent, attributes))
        self._indices[name] = index
        return index

    def find(self, name: str) -> Tuple[Optional[int], Optional[Category]]:
        """
        Looks a label up by name.

        Returns: (index, category), or (None, None) if there is no label
            with such a name
        """
        index = self._indices.get(name)
        if index is not None:
            return index, self.items[index]
        return index, None

    def __getitem__(self, idx: int) -> Category:
        return self.items[idx]

    def __contains__(self, value: Union[int, str]) -> bool:
        # Strings are checked as label names, anything else - as an index
        if isinstance(value, str):
            return self.find(value)[1] is not None
        else:
            return 0 <= value < len(self.items)

    def __len__(self) -> int:
        return len(self.items)

    def __iter__(self) -> Iterator[Category]:
        return iter(self.items)
@attrs(slots=True, order=False)
class Label(Annotation):
    """
    An annotation that carries only a label index.
    """

    _type = AnnotationType.label

    # The label index; the given value is converted with int()
    label: int = field(converter=int)
# A color given as three integer components
RgbColor = Tuple[int, int, int]

# Maps an integer index to its color
Colormap = Dict[int, RgbColor]
@attrs(slots=True, eq=False, order=False)
class MaskCategories(Categories):
    """
    Describes a color map for segmentation masks.
    """

    @classmethod
    def generate(cls, size: int = 255, include_background: bool = True) \
            -> MaskCategories:
        """
        Generates MaskCategories with the specified size.

        If include_background is True, the result will include the item
            "0: (0, 0, 0)", which is typically used as a background color.
        """
        from datumaro.util.mask_tools import generate_colormap
        return cls(generate_colormap(size,
            include_background=include_background))

    # Maps an index to its color
    colormap: Colormap = field(
        factory=dict, validator=default_if_none(dict))

    # Lazily computed color -> index mapping, see 'inverse_colormap'.
    # The sub-validator has to be a real validator: passing the 'dict' type
    # itself makes attrs call dict(instance, attribute, value), which raises
    # TypeError for every non-None value.
    _inverse_colormap: Optional[Dict[RgbColor, int]] = field(
        default=None,
        validator=attr.validators.optional(
            attr.validators.instance_of(dict)))

    @property
    def inverse_colormap(self) -> Dict[RgbColor, int]:
        """
        The inverted colormap, computed with invert_colormap() on the
        first access and cached afterwards.

        NOTE(review): the cached value is not refreshed if 'colormap' is
        modified later.
        """
        from datumaro.util.mask_tools import invert_colormap
        if self._inverse_colormap is None:
            if self.colormap is not None:
                self._inverse_colormap = invert_colormap(self.colormap)
        return self._inverse_colormap

    def __contains__(self, idx: int) -> bool:
        return idx in self.colormap

    def __getitem__(self, idx: int) -> RgbColor:
        return self.colormap[idx]

    def __len__(self) -> int:
        return len(self.colormap)

    def __eq__(self, other):
        """
        Two MaskCategories are equal when they hold the same set of
        indices with equal colors (colors are compared with
        np.array_equal, so tuples, lists and arrays are interchangeable).
        """
        if not isinstance(other, __class__):
            return False

        # The attrs-generated base __eq__ returns NotImplemented when the
        # exact classes differ. It must not be evaluated as a boolean
        # (deprecated, TypeError since Python 3.14); it is treated as
        # "no objection", which is what the truthiness check used to give.
        base_eq = super().__eq__(other)
        if base_eq is not NotImplemented and not base_eq:
            return False

        # Every own entry is matched in 'other' below; equal sizes then
        # guarantee that 'other' has no extra entries, keeping == symmetric
        if len(self.colormap) != len(other.colormap):
            return False

        for label_id, my_color in self.colormap.items():
            other_color = other.colormap.get(label_id)
            if not np.array_equal(my_color, other_color):
                return False
        return True
# 2d array of type bool
BinaryMaskImage = np.ndarray

# 2d array of type int
IndexMaskImage = np.ndarray
@attrs(slots=True, eq=False, order=False)
class Mask(Annotation):
    """
    Represents a 2d single-instance binary segmentation mask.
    """

    _type = AnnotationType.mask

    # Either the mask data itself or a callable returning it
    # (see the 'image' property)
    _image = field()

    label: Optional[int] = field(
        converter=attr.converters.optional(int), default=None, kw_only=True)
    z_order: int = field(
        default=0, validator=default_if_none(int), kw_only=True)

    def __attrs_post_init__(self):
        # Arrays are converted to bool right away; callables are left
        # as is and are only invoked by the 'image' property
        if isinstance(self._image, np.ndarray):
            self._image = self._image.astype(bool)

    @property
    def image(self) -> BinaryMaskImage:
        """The mask data; a callable source is invoked on every access."""
        image = self._image
        if callable(image):
            image = image()
        return image

    def as_class_mask(self, label_id: Optional[int] = None) -> IndexMaskImage:
        """
        Produces a class index mask. Mask label id can be changed.
        """
        if label_id is None:
            label_id = self.label
        from datumaro.util.mask_tools import make_index_mask
        return make_index_mask(self.image, label_id)

    def as_instance_mask(self, instance_id: int) -> IndexMaskImage:
        """
        Produces a instance index mask.
        """
        from datumaro.util.mask_tools import make_index_mask
        return make_index_mask(self.image, instance_id)

    def get_area(self) -> int:
        """Returns the number of non-zero elements of the mask."""
        return np.count_nonzero(self.image)

    def get_bbox(self) -> Tuple[int, int, int, int]:
        """
        Computes the bounding box of the mask.

        Returns: [x, y, w, h]
        """
        from datumaro.util.mask_tools import find_mask_bbox
        return find_mask_bbox(self.image)

    def paint(self, colormap: Colormap) -> np.ndarray:
        """
        Applies a colormap to the mask and produces the resulting image.
        """
        from datumaro.util.mask_tools import paint_mask
        return paint_mask(self.as_class_mask(), colormap)

    def __eq__(self, other):
        """
        Masks are equal when the base annotation fields, the label,
        the z_order and the mask data are equal.
        """
        if not isinstance(other, __class__):
            return False

        # The attrs-generated base __eq__ returns NotImplemented when the
        # exact classes differ (e.g. Mask vs. a Mask subclass). It must not
        # be evaluated as a boolean (deprecated, TypeError since
        # Python 3.14); it is treated as "no objection", which is what the
        # truthiness check used to give.
        base_eq = super().__eq__(other)
        if base_eq is not NotImplemented and not base_eq:
            return False

        return \
            (self.label == other.label) and \
            (self.z_order == other.z_order) and \
            (np.array_equal(self.image, other.image))
@attrs(slots=True, eq=False, order=False)
class RleMask(Mask):
    """
    An RLE-encoded instance segmentation mask.
    """

    # Either the RLE data or a callable returning it
    _rle = field() # uses pycocotools RLE representation

    # The inherited field is not an __init__ argument here; the image is
    # always decoded from the RLE data
    _image = field(init=False, default=None)

    @property
    def image(self) -> BinaryMaskImage:
        """The mask decoded from the RLE data on every access."""
        return self._decode(self.rle)

    @property
    def rle(self):
        """The RLE data; a callable source is invoked on every access."""
        encoded = self._rle
        return encoded() if callable(encoded) else encoded

    @staticmethod
    def _decode(rle):
        from pycocotools import mask as mask_utils
        return mask_utils.decode(rle)

    def get_area(self) -> int:
        """Returns the mask area as reported by pycocotools for the RLE."""
        from pycocotools import mask as mask_utils
        return mask_utils.area(self.rle)

    def get_bbox(self) -> Tuple[int, int, int, int]:
        """Returns the bounding box as reported by pycocotools for the RLE."""
        from pycocotools import mask as mask_utils
        return mask_utils.toBbox(self.rle)

    def __eq__(self, other):
        # NOTE(review): two RleMasks are compared by the RLE data only;
        # label, z_order and the base annotation fields are not checked,
        # unlike Mask.__eq__ - confirm this is intended
        if isinstance(other, __class__):
            return self.rle == other.rle
        return super().__eq__(other)
# 2d array of integers (of different precision)
CompiledMaskImage = np.ndarray
class CompiledMask:
    """
    Represents class- and instance- segmentation masks with
    all the instances (opposed to single-instance masks).
    """

    @staticmethod
    def from_instance_masks(instance_masks: Iterable[Mask],
            instance_ids: Optional[Iterable[int]] = None,
            instance_labels: Optional[Iterable[int]] = None) -> CompiledMask:
        """
        Joins instance masks into a single mask. Masks are sorted by
        z_order (ascending) prior to merging.

        Parameters:
            instance_masks: The masks to be joined. Must not be empty:
                the first mask is taken with next(), so an empty input
                raises StopIteration.
            instance_ids: Instance id values for the produced instance mask.
                By default, mask positions are used.
            instance_labels: Instance label id values for the produced class
                mask. By default, mask labels are used.

        Returns: a CompiledMask holding the merged class and instance masks
        """

        from datumaro.util.mask_tools import make_index_mask

        instance_ids = instance_ids or []
        instance_labels = instance_labels or []
        ordered = sorted(
            zip_longest(instance_masks, instance_ids, instance_labels),
            key=lambda entry: entry[0].z_order)

        # Index 0 is reserved for "no instance"
        max_index = len(ordered) + 1
        index_dtype = np.min_scalar_type(max_index)

        # (mask, 1-based position, instance id, class id), produced lazily
        masks = ((
            m, 1 + i,
            inst_id if inst_id is not None else 1 + i,
            label if label is not None else m.label
        ) for i, (m, inst_id, label) in enumerate(ordered))

        # This optimized version is supposed for:
        # 1. Avoiding memory explosion on materialization of all masks
        # 2. Optimizing mask materialization calls (RLE decoding)
        # 3. Optimizing intermediate mask memory use
        #
        # Basically, a mask can be quite large (e.g. 10k x 10k @ int32 etc.),
        # so we can only afford having just few copies in
        # memory simultaneously.

        it = iter(masks)

        # Lookup tables: position in the index mask -> instance / class id
        instance_map = [0]
        class_map = [0]

        m, idx, instance_id, class_id = next(it)
        if not class_id:
            # A mask with a falsy class id (None or 0) is painted with 0
            idx = 0
        index_mask = make_index_mask(m.image, idx, dtype=index_dtype)
        instance_map.append(instance_id)
        class_map.append(class_id)

        for m, idx, instance_id, class_id in it:
            if not class_id:
                idx = 0
            # Masks coming later in z_order overwrite the earlier ones
            index_mask = np.where(m.image, idx, index_mask)
            instance_map.append(instance_id)
            class_map.append(class_id)

        # Generate compiled masks
        if np.array_equal(instance_map, range(max_index)):
            # The ids match the positions - the index mask can be used as is
            merged_instance_mask = index_mask
        else:
            merged_instance_mask = np.array(instance_map,
                dtype=np.min_scalar_type(instance_map))[index_mask]

        merged_class_mask = np.array(class_map,
            dtype=np.min_scalar_type(class_map))[index_mask]

        return __class__(class_mask=merged_class_mask,
            instance_mask=merged_instance_mask)

    def __init__(self,
            class_mask: Union[None,
                CompiledMaskImage,
                Callable[[], CompiledMaskImage]
            ] = None,
            instance_mask: Union[None,
                CompiledMaskImage,
                Callable[[], CompiledMaskImage]
            ] = None):
        """
        Parameters:
            class_mask: The class index mask, or a callable returning it
            instance_mask: The instance index mask, or a callable
                returning it
        """
        self._class_mask = class_mask
        self._instance_mask = instance_mask

    @staticmethod
    def _get_image(image):
        # A callable source is invoked on every access, nothing is cached
        if callable(image):
            return image()
        return image

    @property
    def class_mask(self) -> Optional[CompiledMaskImage]:
        return self._get_image(self._class_mask)

    @property
    def instance_mask(self) -> Optional[CompiledMaskImage]:
        return self._get_image(self._instance_mask)

    @property
    def instance_count(self) -> int:
        """The maximum value found in the instance mask."""
        return int(self.instance_mask.max())

    def get_instance_labels(self) -> Dict[int, int]:
        """
        Matches the class and instance masks.

        Instance id 0 is skipped. Both ids are packed into a single
        32-bit value per pixel, 16 bits each, so ids above 65535 are not
        represented correctly.

        Returns: { instance id: class id }, with builtin int keys and values
        """
        class_shift = 16
        instance_bits = (1 << class_shift) - 1

        packed = (self.class_mask.astype(np.uint32) << class_shift) \
            + self.instance_mask.astype(np.uint32)

        # tolist() yields builtin ints, so the result matches the declared
        # return type instead of holding numpy scalars
        instance_labels = {}
        for key in np.unique(packed).tolist():
            instance_id = key & instance_bits
            if instance_id != 0:
                instance_labels[instance_id] = key >> class_shift
        return instance_labels
@attrs(slots=True, order=False)
class _Shape(Annotation):
    """
    A base class for annotations described by a flat list of point
    coordinates.
    """

    # Flattened list of point coordinates; the values are rounded to
    # COORDINATE_ROUNDING_DIGITS decimal digits on assignment in __init__
    points: List[float] = field(converter=lambda x: \
        np.around(x, COORDINATE_ROUNDING_DIGITS).tolist())

    label: Optional[int] = field(converter=attr.converters.optional(int),
        default=None, kw_only=True)

    z_order: int = field(default=0, validator=default_if_none(int),
        kw_only=True)

    def get_area(self):
        """Must be implemented in subclasses."""
        raise NotImplementedError()

    def get_bbox(self) -> Tuple[float, float, float, float]:
        """
        Returns [x, y, w, h], or None if there are no points.

        Even positions of 'points' are treated as x coordinates and
        odd positions - as y coordinates.
        """
        coords = self.points
        if not coords:
            return None

        xs = coords[0::2]
        ys = coords[1::2]
        x0, x1 = min(xs), max(xs)
        y0, y1 = min(ys), max(ys)
        return [x0, y0, x1 - x0, y1 - y0]
@attrs(slots=True, order=False)
class PolyLine(_Shape):
    """
    A shape annotation of the 'polyline' type.
    """

    _type = AnnotationType.polyline

    def as_polygon(self):
        """Returns a copy of the flat list of point coordinates."""
        return list(self.points)

    def get_area(self):
        """Always returns 0."""
        return 0
@attrs(slots=True, init=False, order=False)
class Cuboid3d(Annotation):
    """
    A 3d cuboid described by 9 values stored in a single flat list:
    position [x, y, z], rotation [rx, ry, rz] and scale [sx, sy, sz].
    """

    _type = AnnotationType.cuboid_3d

    # [x, y, z, rx, ry, rz, sx, sy, sz]
    _points: List[float] = field(default=None)

    label: Optional[int] = field(converter=attr.converters.optional(int),
        default=None, kw_only=True)

    @_points.validator
    def _points_validator(self, attribute, points):
        # Besides checking the length, this validator normalizes the value:
        # None is replaced with the default cuboid, everything else
        # is rounded
        if points is None:
            points = [0, 0, 0, 0, 0, 0, 1, 1, 1]
        else:
            assert len(points) == 3 + 3 + 3, points
            points = np.around(points, COORDINATE_ROUNDING_DIGITS).tolist()
        self._points = points

    def __init__(self, position, rotation=None, scale=None, **kwargs):
        """
        Parameters:
            position: [x, y, z], exactly 3 values
            rotation: [rx, ry, rz]; [0, 0, 0] if omitted or empty
            scale: [sx, sy, sz]; [1, 1, 1] if omitted or empty
            kwargs: the remaining annotation fields (label, id, ...)
        """
        assert len(position) == 3, position
        if not rotation:
            rotation = [0] * 3
        if not scale:
            scale = [1] * 3

        kwargs.pop('points', None) # comes from wrap()
        self.__attrs_init__(points=[*position, *rotation, *scale], **kwargs)

    def _set_triplet(self, start, value):
        # Writes 3 rounded values into the underlying list in place.
        # The public getters return slice copies, so an assignment made
        # through them would be lost.
        value = np.around(value, COORDINATE_ROUNDING_DIGITS).tolist()
        assert len(value) == 3, value
        self._points[start : start + 3] = value

    @property
    def position(self):
        """[x, y, z]"""
        return self._points[0:3]

    # The setters must carry the name of the property they extend,
    # otherwise the property itself stays read-only
    @position.setter
    def position(self, value):
        # TODO: fix the issue with separate coordinate rounding:
        # self.position[0] = 12.345676
        # - the number assigned won't be rounded (and, since the getter
        # returns a copy, it won't be stored either).
        self._set_triplet(0, value)

    @property
    def rotation(self):
        """[rx, ry, rz]"""
        return self._points[3:6]

    @rotation.setter
    def rotation(self, value):
        self._set_triplet(3, value)

    @property
    def scale(self):
        """[sx, sy, sz]"""
        return self._points[6:9]

    @scale.setter
    def scale(self, value):
        self._set_triplet(6, value)

    def wrap(self, **kwargs):
        """
        Returns a modified copy of the object.

        __init__ takes position / rotation / scale instead of the stored
        'points' field, so they have to be passed explicitly for
        attr.evolve() to be able to call it.
        """
        d = {
            'position': self.position,
            'rotation': self.rotation,
            'scale': self.scale,
        }
        d.update(kwargs)
        return attr.evolve(self, **d)
@attrs(slots=True, order=False)
class Polygon(_Shape):
    """
    A shape annotation of the 'polygon' type. Requires an even number of
    coordinates describing at least 3 points.
    """

    _type = AnnotationType.polygon

    def __attrs_post_init__(self):
        # keep the message on a single line to produce informative output
        assert len(self.points) % 2 == 0 and 3 <= len(self.points) // 2, "Wrong polygon points: %s" % self.points

    def get_area(self):
        """
        Returns the polygon area computed by pycocotools: the polygon is
        encoded with frPyObjects() on a canvas reaching the far corner of
        the bounding box, then measured with area().
        """
        from pycocotools import mask as mask_utils

        left, top, width, height = self.get_bbox()
        encoded = mask_utils.frPyObjects([self.points],
            top + height, left + width)
        return mask_utils.area(encoded)[0]
@attrs(slots=True, init=False, order=False)
class Bbox(_Shape):
    """
    An axis-aligned bounding box. It is constructed from x, y, w, h and
    stored in 'points' as two corners: [x, y, x + w, y + h].
    """

    _type = AnnotationType.bbox

    def __init__(self, x, y, w, h, *args, **kwargs):
        kwargs.pop('points', None) # comes from wrap()
        corners = [x, y, x + w, y + h]
        self.__attrs_init__(corners, *args, **kwargs)

    @property
    def x(self):
        return self.points[0]

    @property
    def y(self):
        return self.points[1]

    @property
    def w(self):
        # The width is derived from the stored corners
        return self.points[2] - self.points[0]

    @property
    def h(self):
        # The height is derived from the stored corners
        return self.points[3] - self.points[1]

    def get_area(self):
        """Returns w * h."""
        return self.w * self.h

    def get_bbox(self):
        """Returns [x, y, w, h]"""
        return [self.x, self.y, self.w, self.h]

    def as_polygon(self):
        """
        Returns the 4 corners as a flat list of coordinates, starting at
        (x, y) and followed by (x + w, y), (x + w, y + h), (x, y + h).
        """
        left, top, width, height = self.get_bbox()
        right = left + width
        bottom = top + height
        return [
            left, top,
            right, top,
            right, bottom,
            left, bottom
        ]

    def iou(self, other: _Shape) -> Union[float, Literal[-1]]:
        """
        Returns the result of bbox_iou() for the bounding boxes of
        this and the other shape.
        """
        from datumaro.util.annotation_util import bbox_iou
        return bbox_iou(self.get_bbox(), other.get_bbox())

    def wrap(item, **kwargs):
        """
        Returns a modified copy of the object. The current x, y, w, h are
        passed along, because __init__ requires them.
        """
        init_args = dict(x=item.x, y=item.y, w=item.w, h=item.h)
        init_args.update(kwargs)
        return attr.evolve(item, **init_args)
@attrs(slots=True, order=False)
class PointsCategories(Categories):
    """
    Describes (key-)point metainfo such as point names and joints.
    """

    @attrs(slots=True, order=False)
    class Category:
        """Point names and joints defined for a single label."""

        # Names for specific points, e.g. eye, nose, mouth etc.
        # These labels are not required to be in LabelCategories
        labels: List[str] = field(
            factory=list, validator=default_if_none(list))

        # Pairs of connected point indices
        joints: Set[Tuple[int, int]] = field(
            factory=set, validator=default_if_none(set))

    # Maps a label id to the point metainfo of this label
    items: Dict[int, Category] = field(
        factory=dict, validator=default_if_none(dict))

    @classmethod
    def from_iterable(cls, iterable: Union[
        Tuple[int, List[str]],
        Tuple[int, List[str], Set[Tuple[int, int]]],
    ]) -> PointsCategories:
        """
        Create PointsCategories from an iterable.

        Args:
            iterable: An Iterable with the following elements:
                - a label id
                - a list of positional arguments for Categories

        Returns:
            PointsCategories: PointsCategories object
        """
        result = cls()
        for add_args in iterable:
            result.add(*add_args)
        return result

    def add(self, label_id: int,
            labels: Optional[Iterable[str]] = None,
            joints: Iterable[Tuple[int, int]] = None):
        """
        Sets the point metainfo for a label id, replacing an existing entry.

        Args:
            label_id: The label id the metainfo belongs to
            labels: Point names
            joints: Pairs of connected point indices; each pair is
                converted to a tuple
        """
        joint_set = set(map(tuple, joints)) if joints is not None else set()
        self.items[label_id] = self.Category(labels, joint_set)

    def __contains__(self, idx: int) -> bool:
        return idx in self.items

    def __getitem__(self, idx: int) -> Category:
        return self.items[idx]

    def __len__(self) -> int:
        return len(self.items)
@attrs(slots=True, order=False)
class Points(_Shape):
    """
    Represents an ordered set of points.
    """

    class Visibility(Enum):
        absent = 0
        hidden = 1
        visible = 2

    _type = AnnotationType.points

    # One Visibility value per point. Raw enum values (0, 1, 2) are
    # accepted on input; every point is 'visible' if nothing is passed.
    visibility: List[Visibility] = field(default=None)

    @visibility.validator
    def _visibility_validator(self, attribute, visibility):
        # Besides checking the length, this validator normalizes the value
        point_count = len(self.points) // 2

        if visibility is None:
            visibility = [self.Visibility.visible] * point_count
        else:
            # Build a new list instead of converting in place: the
            # caller's sequence is left untouched and immutable inputs
            # (e.g. tuples) are supported
            visibility = [
                v if isinstance(v, self.Visibility) else self.Visibility(v)
                for v in visibility
            ]
        assert len(visibility) == point_count
        self.visibility = visibility

    def __attrs_post_init__(self):
        assert len(self.points) % 2 == 0, self.points

    def get_area(self):
        """Always returns 0."""
        return 0

    def get_bbox(self):
        """
        Returns [x, y, w, h] of the box around the points that are not
        'absent'. If there are no such points, returns [0, 0, 0, 0].
        """
        xs = [p for p, v in zip(self.points[0::2], self.visibility)
            if v != __class__.Visibility.absent]
        ys = [p for p, v in zip(self.points[1::2], self.visibility)
            if v != __class__.Visibility.absent]
        x0 = min(xs, default=0)
        x1 = max(xs, default=0)
        y0 = min(ys, default=0)
        y1 = max(ys, default=0)
        return [x0, y0, x1 - x0, y1 - y0]
@attrs(slots=True, order=False)
class Caption(Annotation):
    """
    Represents arbitrary text annotations.
    """

    _type = AnnotationType.caption

    # The annotation text; the given value is converted with str()
    caption: str = field(converter=str)