Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| from typing import Dict, List, Optional, Sequence, Union | |
| from mmseg.registry import DATASETS | |
| from .basesegdataset import BaseSegDataset | |
| try: | |
| from dsdl.dataset import DSDLDataset | |
| except ImportError: | |
| DSDLDataset = None | |
@DATASETS.register_module()
class DSDLSegDataset(BaseSegDataset):
    """Dataset for DSDL-format semantic segmentation.

    Samples are read through ``dsdl.dataset.DSDLDataset`` from the DSDL yaml
    file given as ``ann_file``; each sample must provide the ``Image`` and
    ``LabelMap`` fields.

    Args:
        specific_key_path (dict, optional): Path of specific keys which can
            not be loaded by their field name. Defaults to None, which is
            treated as an empty dict.
        pre_transform (dict, optional): Pre-transform functions applied
            before loading. Defaults to None, which is treated as an empty
            dict.
        used_labels (sequence, optional): Classes actually used in training.
            Must be a subset of the class domain, i.e. ``'background'`` plus
            the DSDL dataset's ``class_names``. Defaults to None (use all
            classes).
        **kwargs: Forwarded to ``BaseSegDataset``. ``ann_file`` is required;
            when ``data_root`` is set, ``ann_file`` is joined onto it.

    Raises:
        RuntimeError: If the ``dsdl`` package is not installed.
    """

    METAINFO = {}

    def __init__(self,
                 specific_key_path: Optional[Dict] = None,
                 pre_transform: Optional[Dict] = None,
                 used_labels: Optional[Sequence] = None,
                 **kwargs) -> None:
        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )
        # Create fresh dicts per instance: a shared ``{}`` default would carry
        # any mutation made downstream into every later instance.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}
        self.used_labels = used_labels

        # Location config handed to DSDLDataset: local file reader with an
        # empty working directory.
        loc_config = dict(type='LocalFileReader', working_dir='')
        # DSDLDataset is built here, before the base initializer runs, so the
        # yaml path must already be resolved against data_root.
        if kwargs.get('data_root'):
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])
        required_fields = ['Image', 'LabelMap']

        # Must exist before BaseSegDataset.__init__: get_label_map and
        # load_data_list (presumably invoked from the base initializer) read
        # self.dsdldataset.
        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )
        BaseSegDataset.__init__(self, **kwargs)

    def load_data_list(self) -> List[Dict]:
        """Load data info from the DSDL yaml file named by ``self.ann_file``.

        As a side effect, sets ``self._metainfo['classes']``: to
        ``used_labels`` when given (also recomputing ``self.label_map``),
        otherwise to ``'background'`` followed by the DSDL class names.

        Returns:
            List[dict]: One info dict per sample with ``img_path``,
            ``seg_map_path``, ``label_map``, ``reduce_zero_label`` and an
            empty ``seg_fields`` list.
        """
        if self.used_labels:
            self._metainfo['classes'] = tuple(self.used_labels)
            self.label_map = self.get_label_map(self.used_labels)
        else:
            self._metainfo['classes'] = tuple(
                ['background'] + list(self.dsdldataset.class_names))

        img_prefix = self.data_prefix['img_path']
        seg_prefix = self.data_prefix['seg_map_path']
        data_list = []
        for data in self.dsdldataset:
            data_list.append(
                dict(
                    img_path=os.path.join(img_prefix,
                                          data['Image'][0].location),
                    seg_map_path=os.path.join(seg_prefix,
                                              data['LabelMap'][0].location),
                    label_map=self.label_map,
                    reduce_zero_label=self.reduce_zero_label,
                    seg_fields=[],
                ))
        return data_list

    def get_label_map(self,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Build the pixel-label mapping from the DSDL classes to new ones.

        The old classes are ``'background'`` followed by the DSDL dataset's
        ``class_names``. The returned dict maps every old label id to the
        index of the same class name in ``new_classes``, or to 255 when the
        class is not kept. A mapping is returned only when ``new_classes`` is
        not None and differs (in content or order) from the old classes.

        Args:
            new_classes (list, tuple, optional): The new class names, e.g.
                from metainfo or ``used_labels``. Defaults to None.

        Returns:
            dict, optional: Mapping from old label ids to new label ids, or
            None when no remapping is needed.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the old classes.
        """
        if new_classes is None:
            return None

        old_classes = ['background'] + list(self.dsdldataset.class_names)
        new_list = list(new_classes)
        if new_list == old_classes:
            return None

        if not set(new_list).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in class_dom.')

        # First occurrence wins, matching ``list.index`` semantics.
        new_index = {}
        for idx, name in enumerate(new_list):
            new_index.setdefault(name, idx)
        # Classes absent from new_classes map to 255.
        return {
            old_id: new_index.get(name, 255)
            for old_id, name in enumerate(old_classes)
        }