# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable

from mmseg.registry import METRICS


@METRICS.register_module()
class IoUMetric(BaseMetric):
    """IoU evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        iou_metrics (list[str] | str): Metrics to be calculated. The options
            include 'mIoU', 'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format the results for submission without
            performing evaluation. It is useful when you want to save the
            results in a specific format and submit them to the test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ignore_index: int = 255,
                 iou_metrics: List[str] = ['mIoU'],
                 nan_to_num: Optional[int] = None,
                 beta: int = 1,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = iou_metrics
        self.nan_to_num = nan_to_num
        self.beta = beta
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only
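
    # Illustrative usage sketch (not part of the original file): in an
    # mmengine/mmseg config this metric is usually selected by its registered
    # type name, e.g.
    #
    #   val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mFscore'])
    #   test_evaluator = val_evaluator
    #
    # The specific metric choices above are illustrative values only.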

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        num_classes = len(self.dataset_meta['classes'])
        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
            # format_only is used for test datasets without ground truth, so
            # the metric is only accumulated when it is disabled.
            if not self.format_only:
                label = data_sample['gt_sem_seg']['data'].squeeze().to(
                    pred_label)
                self.results.append(
                    self.intersect_and_union(pred_label, label, num_classes,
                                             self.ignore_index))
            # format the result and save the prediction mask as a PNG file
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy()
                # The index range of the official ADE20K dataset is 0 to 150,
                # but the index range of the output is 0 to 149 because
                # reduce_zero_label=True is set.
                if data_sample.get('reduce_zero_label', False):
                    output_mask = output_mask + 1
                output = Image.fromarray(output_mask.astype(np.uint8))
                output.save(png_filename)
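
    # Illustrative sketch (not part of the original file) of the per-sample
    # dict layout that ``process`` assumes, following the usual
    # ``SegDataSample``-style packing in mmseg:
    #
    #   data_sample = dict(
    #       pred_sem_seg=dict(data=torch.randint(0, 19, (1, 512, 512))),
    #       gt_sem_seg=dict(data=torch.randint(0, 19, (1, 512, 512))),
    #       img_path='demo.png')
    #   metric.process(data_batch={}, data_samples=[data_sample])
    #
    # The class count (19), image size and path are hypothetical values.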

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
                the metrics, and the values are corresponding results. The
                keys mainly include aAcc, mIoU, mAcc, mDice, mFscore,
                mPrecision and mRecall.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(
                f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()
        # Convert a list of tuples to a tuple of lists, e.g.
        # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
        # ([A_1, ..., A_n], ..., [D_1, ..., D_n]).
        results = tuple(zip(*results))
        assert len(results) == 4

        total_area_intersect = sum(results[0])
        total_area_union = sum(results[1])
        total_area_pred_label = sum(results[2])
        total_area_label = sum(results[3])
        ret_metrics = self.total_area_to_metrics(
            total_area_intersect, total_area_union, total_area_pred_label,
            total_area_label, self.metrics, self.nan_to_num, self.beta)

        class_names = self.dataset_meta['classes']

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        metrics = dict()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                metrics[key] = val
            else:
                metrics['m' + key] = val

        # per-class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)

        return metrics
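
    # Illustrative sketch (not part of the original file) of the summary dict
    # returned by ``compute_metrics`` when ``iou_metrics=['mIoU']``; the
    # numbers are made up purely to show the shape of the output:
    #
    #   {'aAcc': 96.32, 'mIoU': 78.14, 'mAcc': 85.47}
    #
    # Per-class values are not returned; they are only printed as a table.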

    @staticmethod
    def intersect_and_union(pred_label: torch.Tensor, label: torch.Tensor,
                            num_classes: int, ignore_index: int):
        """Calculate the intersection and union histograms.

        Args:
            pred_label (torch.Tensor): Prediction segmentation map with
                shape (H, W).
            label (torch.Tensor): Ground truth segmentation map with
                shape (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.

        Returns:
            torch.Tensor: The intersection of prediction and ground truth
                histograms on all classes.
            torch.Tensor: The union of prediction and ground truth histograms
                on all classes.
            torch.Tensor: The prediction histogram on all classes.
            torch.Tensor: The ground truth histogram on all classes.
        """
        mask = (label != ignore_index)
        pred_label = pred_label[mask]
        label = label[mask]

        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=num_classes, min=0,
            max=num_classes - 1).cpu()
        area_pred_label = torch.histc(
            pred_label.float(), bins=num_classes, min=0,
            max=num_classes - 1).cpu()
        area_label = torch.histc(
            label.float(), bins=num_classes, min=0,
            max=num_classes - 1).cpu()
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label
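
    # Tiny worked example (not part of the original file), with num_classes=3
    # and ignore_index=255:
    #
    #   pred_label = [0, 1, 1], label = [0, 1, 2]
    #   area_intersect  = [1, 1, 0]   # correctly predicted pixels per class
    #   area_pred_label = [1, 2, 0]
    #   area_label      = [1, 1, 1]
    #   area_union      = [1, 2, 1]   # pred + label - intersect
    #
    # so the resulting per-class IoU would be [1.0, 0.5, 0.0].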

    @staticmethod
    def total_area_to_metrics(total_area_intersect: np.ndarray,
                              total_area_union: np.ndarray,
                              total_area_pred_label: np.ndarray,
                              total_area_label: np.ndarray,
                              metrics: List[str] = ['mIoU'],
                              nan_to_num: Optional[int] = None,
                              beta: int = 1):
        """Calculate evaluation metrics.

        Args:
            total_area_intersect (np.ndarray): The intersection of prediction
                and ground truth histograms on all classes.
            total_area_union (np.ndarray): The union of prediction and ground
                truth histograms on all classes.
            total_area_pred_label (np.ndarray): The prediction histogram on
                all classes.
            total_area_label (np.ndarray): The ground truth histogram on
                all classes.
            metrics (List[str] | str): Metrics to be evaluated. The options
                are 'mIoU', 'mDice' and 'mFscore'.
            nan_to_num (int, optional): If specified, NaN values will be
                replaced by the numbers defined by the user. Default: None.
            beta (int): Determines the weight of recall in the combined score.
                Default: 1.

        Returns:
            Dict[str, np.ndarray]: Per-category evaluation metrics, each of
                shape (num_classes, ).
        """

        def f_score(precision, recall, beta=1):
            """Calculate the F-score value.

            Args:
                precision (float | torch.Tensor): The precision value.
                recall (float | torch.Tensor): The recall value.
                beta (int): Determines the weight of recall in the combined
                    score. Default: 1.

            Returns:
                [torch.Tensor]: The F-score value.
            """
            score = (1 + beta**2) * (precision * recall) / (
                (beta**2 * precision) + recall)
            return score
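
        # Worked example (not part of the original file): with beta=1,
        # precision=0.5 and recall=1.0, the score is
        # 2 * 0.5 * 1.0 / (0.5 + 1.0) = 0.667, i.e. the usual F1 score.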

        if isinstance(metrics, str):
            metrics = [metrics]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f'metrics {metrics} is not supported')

        all_acc = total_area_intersect.sum() / total_area_label.sum()
        ret_metrics = OrderedDict({'aAcc': all_acc})
        for metric in metrics:
            if metric == 'mIoU':
                iou = total_area_intersect / total_area_union
                acc = total_area_intersect / total_area_label
                ret_metrics['IoU'] = iou
                ret_metrics['Acc'] = acc
            elif metric == 'mDice':
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
                acc = total_area_intersect / total_area_label
                ret_metrics['Dice'] = dice
                ret_metrics['Acc'] = acc
            elif metric == 'mFscore':
                precision = total_area_intersect / total_area_pred_label
                recall = total_area_intersect / total_area_label
                f_value = torch.tensor([
                    f_score(x[0], x[1], beta) for x in zip(precision, recall)
                ])
                ret_metrics['Fscore'] = f_value
                ret_metrics['Precision'] = precision
                ret_metrics['Recall'] = recall

        ret_metrics = {
            metric: value.numpy()
            for metric, value in ret_metrics.items()
        }
        if nan_to_num is not None:
            ret_metrics = OrderedDict({
                metric: np.nan_to_num(metric_value, nan=nan_to_num)
                for metric, metric_value in ret_metrics.items()
            })
        return ret_metrics
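

if __name__ == '__main__':
    # Minimal self-check sketch (not part of the original file): exercise the
    # two static helpers directly on random tensors to show the expected
    # shapes, without building a full mmengine evaluation loop. The class
    # count and image size below are arbitrary.
    num_classes = 4
    pred = torch.randint(0, num_classes, (32, 32))
    gt = torch.randint(0, num_classes, (32, 32))
    areas = IoUMetric.intersect_and_union(
        pred, gt, num_classes, ignore_index=255)
    per_class = IoUMetric.total_area_to_metrics(*areas, metrics=['mIoU'])
    # per_class['IoU'] and per_class['Acc'] are arrays of shape (4,), while
    # per_class['aAcc'] is the scalar overall pixel accuracy.
    print({k: np.round(v, 4) for k, v in per_class.items()})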