# Source code for datumaro.plugins.imagenet_txt_format
# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT
from enum import Enum, auto
from typing import Iterable, Optional, Sequence, Tuple, Union
import os
import os.path as osp
from datumaro.components.annotation import (
AnnotationType, Label, LabelCategories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.converter import Converter
from datumaro.components.errors import DatasetImportError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file
class _LabelsSource(Enum):
file = auto()
generate = auto()
def _parse_annotation_line(line: str) -> Tuple[str, str, Sequence[int]]:
item = line.split('\"')
if 1 < len(item):
if len(item) == 3:
item_id = item[1]
item = item[2].split()
image = item_id + item[0]
label_ids = [int(id) for id in item[1:]]
else:
raise Exception("Line %s: unexpected number "
"of quotes in filename" % line)
else:
item = line.split()
item_id = osp.splitext(item[0])[0]
image = item[0]
label_ids = [int(id) for id in item[1:]]
return item_id, image, label_ids
class ImagenetTxtExtractor(SourceExtractor):
    """
    Reads one subset of a dataset stored in the ImageNet-like text format.

    Each non-blank line of the annotation file describes one image and its
    label ids. Image paths are resolved against the image directory.
    """

    def __init__(self, path: str, *,
            labels: Union[Iterable[str], str] = _LabelsSource.file.name,
            labels_file: str = ImagenetTxtPath.LABELS_FILE,
            image_dir: Optional[str] = None,
            subset: Optional[str] = None,
        ):
        """
        Args:
            path: Path to the annotation file.
            labels: Either an iterable of label names, or the name of a
                `_LabelsSource` member: 'file' takes the names from the meta
                file in the annotation file directory if `has_meta_file`
                finds one, otherwise from `labels_file`; 'generate' creates
                'class-<N>' names for every label id found in annotations.
            labels_file: Labels file path, resolved against the annotation
                file directory. Used only with labels='file'.
            image_dir: Image directory, resolved against the annotation file
                directory. Defaults to `ImagenetTxtPath.IMAGE_DIR`.
            subset: Subset name. Defaults to the annotation file name
                without its extension.

        Raises:
            KeyError: if `labels` is a string that is not a known source.
            DatasetImportError: on invalid or unknown label ids.
        """
        assert osp.isfile(path), path

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]
        super().__init__(subset=subset)

        root_dir = osp.dirname(path)
        if not image_dir:
            image_dir = ImagenetTxtPath.IMAGE_DIR
        self.image_dir = osp.join(root_dir, image_dir)

        self._generate_labels = False
        if isinstance(labels, str):
            labels_source = _LabelsSource[labels]
            if labels_source == _LabelsSource.generate:
                labels = ()
                self._generate_labels = True
            elif labels_source == _LabelsSource.file:
                if has_meta_file(root_dir):
                    labels = parse_meta_file(root_dir).keys()
                else:
                    labels = self._parse_labels(
                        osp.join(root_dir, labels_file))
            else:
                assert False, "Unhandled labels source %s" % labels_source
        else:
            # Materialize first: the check below iterates the labels, which
            # would otherwise exhaust a one-shot iterator before the
            # categories are built from it.
            labels = list(labels)
            assert all(isinstance(e, str) for e in labels)

        self._categories = self._load_categories(labels)
        self._items = list(self._load_items(path).values())

    @staticmethod
    def _parse_labels(path):
        """Reads label names from a text file, one stripped name per line."""
        with open(path, encoding='utf-8') as labels_file:
            return [s.strip() for s in labels_file]

    def _load_categories(self, labels):
        """Builds the categories dict with a label category per name."""
        return { AnnotationType.label: LabelCategories.from_iterable(labels) }

    def _load_items(self, path):
        """
        Parses the annotation file into a dict of item id -> DatasetItem.

        Blank lines are skipped. When labels are generated, the label
        categories are extended in place to cover every label id found.
        A later line with the same item id replaces the earlier item.

        Raises:
            DatasetImportError: on a negative label id, or on a label id
                outside the known categories when labels are not generated.
        """
        items = {}
        # The same categories object is extended in place below,
        # so it only needs to be looked up once.
        label_categories = self._categories[AnnotationType.label]

        with open(path, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue

                item_id, image, label_ids = _parse_annotation_line(line)

                anno = []
                for label in label_ids:
                    if label < 0:
                        raise DatasetImportError(
                            f"Image '{item_id}': invalid label id '{label}'")
                    if len(label_categories) <= label:
                        if not self._generate_labels:
                            raise DatasetImportError(
                                f"Image '{item_id}': unknown label id '{label}'")
                        while len(label_categories) <= label:
                            label_categories.add(
                                f"class-{len(label_categories)}")
                    anno.append(Label(label))

                items[item_id] = DatasetItem(id=item_id, subset=self._subset,
                    image=osp.join(self.image_dir, image), annotations=anno)
        return items
class ImagenetTxtImporter(Importer, CliPlugin):
    """
    Detects and imports datasets stored in the ImageNet-like text format.
    """

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """
        Requires a '*.txt' file (other than the labels file) that looks like
        an ImageNet-like annotation file: its lines, up to the first one
        carrying a label id, must parse as annotation lines, and at least
        one line must carry a label id. Blank lines are ignored.
        """
        annot_path = context.require_file('*.txt',
            exclude_fnames=ImagenetTxtPath.LABELS_FILE)

        # NOTE(review): presumably any exception raised inside
        # probe_text_file() is reported as an unmet requirement instead of
        # propagating - confirm in FormatDetectionContext.
        with context.probe_text_file(
            annot_path,
            "must be an ImageNet-like annotation file",
        ) as f:
            for line in f:
                if not line.strip():
                    continue

                _, _, label_ids = _parse_annotation_line(line)
                if label_ids:
                    break
            else:
                # If there are no labels in the entire file, it's probably
                # not actually an ImageNet file.
                raise ValueError("no label ids found in the annotation file")

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Adds the --labels and --labels-file import options."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('--labels',
            choices=_LabelsSource.__members__,
            default=_LabelsSource.file.name,
            help="Where to get label descriptions from (use "
                "'file' to load from the file specified by --labels-file; "
                "'generate' to create generic ones)")
        parser.add_argument('--labels-file',
            default=ImagenetTxtPath.LABELS_FILE,
            help="Path to the file with label descriptions (synsets.txt)")
        return parser

    @classmethod
    def find_sources_with_params(cls, path, **extra_params):
        """
        Finds '.txt' annotation files under `path`. When labels are read
        from a file (the default), files whose base name matches the labels
        file are excluded from the candidates.
        """
        if 'labels' not in extra_params \
                or extra_params['labels'] == _LabelsSource.file.name:
            labels_file_name = osp.basename(
                extra_params.get('labels_file') or ImagenetTxtPath.LABELS_FILE)

            def file_filter(p):
                return osp.basename(p) != labels_file_name
        else:
            file_filter = None

        return cls._find_sources_recursive(path, '.txt', 'imagenet_txt',
            file_filter=file_filter)
class ImagenetTxtConverter(Converter):
    """
    Exports a dataset in the ImageNet-like text format.

    Writes one '<subset>.txt' annotation file per subset into the save
    directory, plus either the dataset meta file or a labels file with
    one label name per line.
    """

    DEFAULT_IMAGE_EXT = '.jpg'

    def apply(self):
        subset_dir = self._save_dir
        os.makedirs(subset_dir, exist_ok=True)

        extractor = self._extractor
        for subset_name, subset in extractor.subsets().items():
            annotation_file = osp.join(subset_dir, '%s.txt' % subset_name)

            labels = {}
            for item in subset:
                item_id = item.id
                # Quote ids with whitespace anywhere (including leading or
                # trailing), otherwise the reader's whitespace split would
                # break them apart.
                # NOTE(review): ids containing '"' cannot be represented.
                if any(c.isspace() for c in item_id):
                    item_id = '"' + item_id + '"'
                item_id += self._find_image_ext(item)
                # Sorted for deterministic output.
                labels[item_id] = sorted({p.label for p in item.annotations
                    if p.type == AnnotationType.label})

                if self._save_images and item.has_image:
                    self._save_image(item, subdir=ImagenetTxtPath.IMAGE_DIR)

            with open(annotation_file, 'w', encoding='utf-8') as f:
                f.writelines(
                    '%s %s\n' % (item_id,
                        ' '.join(str(label) for label in item_labels))
                    for item_id, item_labels in labels.items())

        # Label names are the same for every subset, so they are saved once.
        if self._save_dataset_meta:
            self._save_meta_file(subset_dir)
        else:
            labels_file = osp.join(subset_dir, ImagenetTxtPath.LABELS_FILE)
            with open(labels_file, 'w', encoding='utf-8') as f:
                f.writelines(label.name + '\n'
                    for label in extractor.categories().get(
                        AnnotationType.label, LabelCategories())
                )