Source code for datumaro.plugins.tf_detection_api_format.extractor
# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import OrderedDict
import os.path as osp
import re
import numpy as np
from datumaro.components.annotation import (
AnnotationType, Bbox, LabelCategories, Mask,
)
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.media import ByteImage
from datumaro.util.image import decode_image, lazy_image
from datumaro.util.tf_util import import_tf as _import_tf
from .format import DetectionApiPath
# TensorFlow is obtained through the project's import helper instead of a
# plain ``import``; the result is used as the ``tf`` namespace below.
tf = _import_tf()
class TfDetectionApiExtractor(SourceExtractor):
    """Extractor for a TFRecord file that uses TF Detection API feature keys.

    The whole file is parsed in the constructor: the dataset items end up in
    ``self._items`` and the label categories in ``self._categories``.
    """

    def __init__(self, path, subset=None):
        """Parse the TFRecord file at ``path``.

        Args:
            path: Path to an existing TFRecord file.
            subset: Subset name for the produced items. When it is empty,
                the file name without its extension is used instead.
        """
        assert osp.isfile(path), path

        # When the record file lies in the annotations directory, images are
        # looked up in the sibling images directory, provided it exists.
        images_dir = ''
        root_dir = osp.dirname(osp.abspath(path))
        if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR:
            root_dir = osp.dirname(root_dir)
            images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR)
            if not osp.isdir(images_dir):
                images_dir = ''

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]

        super().__init__(subset=subset)

        items, labels = self._parse_tfrecord_file(
            path, self._subset, images_dir)
        self._items = items
        self._categories = self._load_categories(labels)

    @staticmethod
    def _load_categories(labels):
        """Build the categories dict from a ``{label name: index}`` mapping.

        Label names are handed to ``LabelCategories`` ordered by ascending
        index (ties keep the mapping's iteration order).

        NOTE(review): annotations carry the raw index (``id - 1``), while the
        categories here are built positionally - presumably the two agree
        only when the indices are contiguous from 0; confirm.
        """
        ordered_names = sorted(labels, key=labels.get)
        label_categories = LabelCategories().from_iterable(ordered_names)
        return { AnnotationType.label: label_categories }

    @staticmethod
    def _clamp(value, lower, upper):
        """Limit ``value`` to the closed interval ``[lower, upper]``."""
        return max(min(upper, value), lower)

    @classmethod
    def _parse_labelmap(cls, text):
        """Extract ``{name: id}`` pairs from label map text.

        An entry is a ``{...}`` group holding exactly one ``id: <digits>``
        field and one quoted ``name: '<text>'`` field, in either order.
        Groups that do not have this shape are ignored.

        Args:
            text: Contents of a label map file.

        Returns:
            A dict mapping label names to integer ids.
        """
        id_pattern = r'(?:id\s*:\s*(?P<id>\d+))'
        name_pattern = r'(?:name\s*:\s*[\'\"](?P<name>.*?)[\'\"])'
        # One match per entry: a trailing "+" here would merge adjacent
        # entries into a single match and keep only the last id/name pair.
        entry_pattern = r'\{(?:\s*(?:%(id)s|%(name)s)\s*){2}\}' % \
            {'id': id_pattern, 'name': name_pattern}

        labelmap = {}
        for match in re.finditer(entry_pattern, text):
            label_id = match.group('id')
            label_name = match.group('name')
            # Both alternatives may match the same field twice
            # (e.g. two ids), leaving the other group unset.
            if label_id is None or label_name is None:
                continue
            labelmap[label_name] = int(label_id)

        return labelmap

    @classmethod
    def _parse_tfrecord_file(cls, filepath, subset, images_dir):
        """Read every record of a TFRecord file into dataset items.

        Args:
            filepath: Path to the TFRecord file. A label map file located in
                the same directory is read first, when it exists.
            subset: Subset name assigned to every produced item.
            images_dir: Directory joined with the stored file name to form
                the image path; may be an empty string.

        Returns:
            A tuple ``(items, labels)`` where ``items`` is a list of
            ``DatasetItem`` and ``labels`` is an ``OrderedDict`` mapping
            label names to zero-based indices (``id - 1``).
        """
        dataset = tf.data.TFRecordDataset(filepath)
        features = {
            'image/filename': tf.io.FixedLenFeature([], tf.string),
            'image/source_id': tf.io.FixedLenFeature([], tf.string),
            'image/height': tf.io.FixedLenFeature([], tf.int64),
            'image/width': tf.io.FixedLenFeature([], tf.int64),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/format': tf.io.FixedLenFeature([], tf.string),

            # use varlen to avoid errors when this field is missing
            'image/key/sha256': tf.io.VarLenFeature(tf.string),

            # Object boxes and classes.
            'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
            'image/object/class/label': tf.io.VarLenFeature(tf.int64),
            'image/object/class/text': tf.io.VarLenFeature(tf.string),
            'image/object/mask': tf.io.VarLenFeature(tf.string),
        }

        # Labels from the label map file come first; ids are shifted by one
        # to obtain zero-based indices.
        dataset_labels = OrderedDict()
        labelmap_path = osp.join(osp.dirname(filepath),
            DetectionApiPath.LABELMAP_FILE)
        if osp.exists(labelmap_path):
            with open(labelmap_path, 'r', encoding='utf-8') as f:
                labelmap_text = f.read()
            dataset_labels.update({
                name: index - 1
                for name, index in cls._parse_labelmap(labelmap_text).items()
            })

        dataset_items = []

        for record in dataset:
            parsed_record = tf.io.parse_single_example(record, features)
            frame_id = parsed_record['image/source_id'].numpy().decode('utf-8')
            frame_filename = \
                parsed_record['image/filename'].numpy().decode('utf-8')
            frame_height = tf.cast(
                parsed_record['image/height'], tf.int64).numpy().item()
            frame_width = tf.cast(
                parsed_record['image/width'], tf.int64).numpy().item()
            frame_image = parsed_record['image/encoded'].numpy()
            xmins = tf.sparse.to_dense(
                parsed_record['image/object/bbox/xmin']).numpy()
            ymins = tf.sparse.to_dense(
                parsed_record['image/object/bbox/ymin']).numpy()
            xmaxs = tf.sparse.to_dense(
                parsed_record['image/object/bbox/xmax']).numpy()
            ymaxs = tf.sparse.to_dense(
                parsed_record['image/object/bbox/ymax']).numpy()
            label_ids = tf.sparse.to_dense(
                parsed_record['image/object/class/label']).numpy()
            labels = tf.sparse.to_dense(
                parsed_record['image/object/class/text'],
                default_value=b'').numpy()
            masks = tf.sparse.to_dense(
                parsed_record['image/object/mask'],
                default_value=b'').numpy()

            # Register labels seen in this record. Names that are already
            # known (from the label map or earlier records) keep their index.
            for raw_name, label_id in zip(labels, label_ids):
                label_name = raw_name.decode('utf-8')
                if not label_name or label_id <= 0:
                    continue
                if label_name in dataset_labels:
                    continue
                # Store a plain int, like the indices from the label map.
                dataset_labels[label_name] = int(label_id) - 1

            item_id = osp.splitext(frame_filename)[0]

            annotations = []
            # dstack packs the per-object columns into rows of
            # (text, xmin, ymin, xmax, ymax).
            for shape_id, shape in enumerate(
                    np.dstack((labels, xmins, ymins, xmaxs, ymaxs))[0]):
                label_index = dataset_labels.get(shape[0].decode('utf-8'))

                mask = None
                if len(masks) != 0:
                    mask = masks[shape_id]

                if mask is not None:
                    # When the record carries masks, each object yields a
                    # mask annotation and its box is not used.
                    if isinstance(mask, bytes):
                        mask = lazy_image(mask, decode_image)
                    annotations.append(Mask(image=mask, label=label_index))
                else:
                    # Box coordinates are scaled by the frame size and
                    # limited to the frame area.
                    x = cls._clamp(shape[1] * frame_width, 0, frame_width)
                    y = cls._clamp(shape[2] * frame_height, 0, frame_height)
                    w = cls._clamp(
                        shape[3] * frame_width, 0, frame_width) - x
                    h = cls._clamp(
                        shape[4] * frame_height, 0, frame_height) - y
                    annotations.append(Bbox(x, y, w, h, label=label_index))

            image_size = None
            if frame_height and frame_width:
                image_size = (frame_height, frame_width)

            # An image object is created only when the record has encoded
            # data and/or a file name.
            image_params = {}
            if frame_image:
                image_params['data'] = frame_image
            if frame_filename:
                image_params['path'] = osp.join(images_dir, frame_filename)

            image = None
            if image_params:
                image = ByteImage(**image_params, size=image_size)

            dataset_items.append(DatasetItem(id=item_id, subset=subset,
                image=image, annotations=annotations,
                attributes={'source_id': frame_id}))

        return dataset_items, dataset_labels