Source code for datumaro.plugins.imagenet_txt_format

# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from enum import Enum, auto
from typing import Iterable, Optional, Sequence, Tuple, Union
import os
import os.path as osp

from datumaro.components.annotation import (
    AnnotationType, Label, LabelCategories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.converter import Converter
from datumaro.components.errors import DatasetImportError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


[docs]class ImagenetTxtPath: LABELS_FILE = 'synsets.txt' IMAGE_DIR = 'images'
class _LabelsSource(Enum): file = auto() generate = auto() def _parse_annotation_line(line: str) -> Tuple[str, str, Sequence[int]]: item = line.split('\"') if 1 < len(item): if len(item) == 3: item_id = item[1] item = item[2].split() image = item_id + item[0] label_ids = [int(id) for id in item[1:]] else: raise Exception("Line %s: unexpected number " "of quotes in filename" % line) else: item = line.split() item_id = osp.splitext(item[0])[0] image = item[0] label_ids = [int(id) for id in item[1:]] return item_id, image, label_ids
[docs]class ImagenetTxtExtractor(SourceExtractor):
[docs] def __init__(self, path: str, *, labels: Union[Iterable[str], str] = _LabelsSource.file.name, labels_file: str = ImagenetTxtPath.LABELS_FILE, image_dir: Optional[str] = None, subset: Optional[str] = None, ): assert osp.isfile(path), path if not subset: subset = osp.splitext(osp.basename(path))[0] super().__init__(subset=subset) root_dir = osp.dirname(path) if not image_dir: image_dir = ImagenetTxtPath.IMAGE_DIR self.image_dir = osp.join(root_dir, image_dir) self._generate_labels = False if isinstance(labels, str): labels_source = _LabelsSource[labels] if labels_source == _LabelsSource.generate: labels = () self._generate_labels = True elif labels_source == _LabelsSource.file: if has_meta_file(root_dir): labels = parse_meta_file(root_dir).keys() else: labels = self._parse_labels( osp.join(root_dir, labels_file)) else: assert False, "Unhandled labels source %s" % labels_source else: assert all(isinstance(e, str) for e in labels) self._categories = self._load_categories(labels) self._items = list(self._load_items(path).values())
@staticmethod def _parse_labels(path): with open(path, encoding='utf-8') as labels_file: return [s.strip() for s in labels_file] def _load_categories(self, labels): return { AnnotationType.label: LabelCategories.from_iterable(labels) } def _load_items(self, path): items = {} with open(path, encoding='utf-8') as f: for line in f: item_id, image, label_ids = _parse_annotation_line(line) anno = [] label_categories = self._categories[AnnotationType.label] for label in label_ids: if label < 0: raise DatasetImportError( f"Image '{item_id}': invalid label id '{label}'") if len(label_categories) <= label: if self._generate_labels: while len(label_categories) <= label: label_categories.add(f"class-{len(label_categories)}") else: raise DatasetImportError( f"Image '{item_id}': unknown label id '{label}'") anno.append(Label(label)) items[item_id] = DatasetItem(id=item_id, subset=self._subset, image=osp.join(self.image_dir, image), annotations=anno) return items
[docs]class ImagenetTxtImporter(Importer, CliPlugin):
[docs] @classmethod def detect(cls, context: FormatDetectionContext) -> None: annot_path = context.require_file('*.txt', exclude_fnames=ImagenetTxtPath.LABELS_FILE) with context.probe_text_file( annot_path, "must be an ImageNet-like annotation file", ) as f: for line in f: _, _, label_ids = _parse_annotation_line(line) if label_ids: break else: # If there are no labels in the entire file, it's probably # not actually an ImageNet file. raise Exception
[docs] @classmethod def build_cmdline_parser(cls, **kwargs): parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--labels', choices=_LabelsSource.__members__, default=_LabelsSource.file.name, help="Where to get label descriptions from (use " "'file' to load from the file specified by --labels-file; " "'generate' to create generic ones)") parser.add_argument('--labels-file', default=ImagenetTxtPath.LABELS_FILE, help="Path to the file with label descriptions (synsets.txt)") return parser
[docs] @classmethod def find_sources_with_params(cls, path, **extra_params): if 'labels' not in extra_params \ or extra_params['labels'] == _LabelsSource.file.name: labels_file_name = osp.basename( extra_params.get('labels_file') or ImagenetTxtPath.LABELS_FILE) def file_filter(p): return osp.basename(p) != labels_file_name else: file_filter = None return cls._find_sources_recursive(path, '.txt', 'imagenet_txt', file_filter=file_filter)
[docs]class ImagenetTxtConverter(Converter): DEFAULT_IMAGE_EXT = '.jpg'
[docs] def apply(self): subset_dir = self._save_dir os.makedirs(subset_dir, exist_ok=True) extractor = self._extractor for subset_name, subset in self._extractor.subsets().items(): annotation_file = osp.join(subset_dir, '%s.txt' % subset_name) labels = {} for item in subset: item_id = item.id if 1 < len(item_id.split()): item_id = '\"' + item_id + '\"' item_id += self._find_image_ext(item) labels[item_id] = set(p.label for p in item.annotations if p.type == AnnotationType.label) if self._save_images and item.has_image: self._save_image(item, subdir=ImagenetTxtPath.IMAGE_DIR) annotation = '' for item_id, item_labels in labels.items(): annotation += '%s %s\n' % (item_id, ' '.join(str(l) for l in item_labels)) with open(annotation_file, 'w', encoding='utf-8') as f: f.write(annotation) if self._save_dataset_meta: self._save_meta_file(subset_dir) else: labels_file = osp.join(subset_dir, ImagenetTxtPath.LABELS_FILE) with open(labels_file, 'w', encoding='utf-8') as f: f.writelines(l.name + '\n' for l in extractor.categories().get( AnnotationType.label, LabelCategories()) )