Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Union

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
class BaseMatchCost:
    """Base class for match costs.

    A match cost scores every (prediction, ground truth) pair. Concrete
    costs override ``__call__`` and scale their result by ``weight``.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        self.weight = weight

    # The class does not derive from ``abc.ABC``, so this marker documents
    # the contract but does not block instantiation of the base class.
    @abstractmethod
    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Subclasses must override this method; the base implementation does
        nothing and returns ``None``.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        pass
# NOTE(review): ``TASK_UTILS`` is imported by this file but no registration
# decorator is visible on this class - confirm whether one was dropped.
class ClassificationCost(BaseMatchCost):
    """Classification cost based on softmax probabilities.

    The cost of matching a prediction to a ground truth is the negated
    softmax probability the prediction assigns to the ground-truth class,
    multiplied by ``weight``.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> import torch
        >>> self = ClassificationCost()
        >>> pred_instances = InstanceData(scores=torch.rand(4, 3))
        >>> gt_instances = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> self(pred_instances, gt_instances)  # values depend on rand
        tensor([[-0.3430, -0.3525, -0.3045],
                [-0.3077, -0.2931, -0.3992],
                [-0.3664, -0.3455, -0.2881],
                [-0.3343, -0.2701, -0.3956]])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): "scores" inside is
                predicted classification logits, of shape
                (num_queries, num_class).
            gt_instances (InstanceData): "labels" inside should have
                shape (num_gt, ).

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"
        pred_scores = pred_instances.scores
        gt_labels = gt_instances.labels

        # Turn logits into per-class probabilities.
        pred_scores = pred_scores.softmax(-1)
        # Column j holds every query's probability for the class of gt j;
        # negate so that a more confident match has a lower cost.
        cls_cost = -pred_scores[:, gt_labels]

        return cls_cost * self.weight
# NOTE(review): ``TASK_UTILS`` is imported by this file but no registration
# decorator is visible on this class - confirm whether one was dropped.
class DiceCost(BaseMatchCost):
    """Cost of mask assignments based on dice losses.

    Args:
        pred_act (bool): Whether to apply sigmoid to mask_pred.
            Defaults to False.
        eps (float): Smoothing term added to both the numerator and the
            denominator of the dice ratio. Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss
            in which the power of the number in the denominator is
            the first power. If False, use the second power that
            is adopted by K-Net and SOLO. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """Compute the pairwise dice loss between predictions and targets.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *)
                store 0 or 1, 0 for negative class and 1 for
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        mask_preds = mask_preds.flatten(1)
        # ``einsum`` needs both operands in one dtype. Promote to at least
        # float32 so that non-float32 predictions (fp16, bf16, float64) no
        # longer clash with the float32 ground truth; float32 inputs are
        # left exactly as they were.
        compute_dtype = torch.promote_types(mask_preds.dtype, torch.float32)
        mask_preds = mask_preds.to(compute_dtype)
        gt_masks = gt_masks.flatten(1).to(compute_dtype)

        # Pairwise overlap of every prediction with every ground truth.
        numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
        if self.naive_dice:
            denominator = mask_preds.sum(-1)[:, None] + \
                gt_masks.sum(-1)[None, :]
        else:
            denominator = mask_preds.pow(2).sum(1)[:, None] + \
                gt_masks.pow(2).sum(1)[None, :]
        loss = 1 - (numerator + self.eps) / (denominator + self.eps)
        return loss

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        if self.pred_act:
            pred_masks = pred_masks.sigmoid()
        dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)

        return dice_cost * self.weight
# NOTE(review): ``TASK_UTILS`` is imported by this file but no registration
# decorator is visible on this class - confirm whether one was dropped.
class CrossEntropyLossCost(BaseMatchCost):
    """Cost of mask assignments based on binary cross-entropy.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Only the sigmoid variant is implemented.
            Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """Compute the pairwise binary cross-entropy cost.

        Args:
            cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
                (num_queries, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
        """
        cls_pred = cls_pred.flatten(1).float()
        gt_labels = gt_labels.flatten(1).float()
        n = cls_pred.shape[1]
        # Per-element loss of each logit against an all-ones target (pos)
        # and an all-zeros target (neg); the real targets select between
        # the two in the einsum below.
        pos = F.binary_cross_entropy_with_logits(
            cls_pred, torch.ones_like(cls_pred), reduction='none')
        neg = F.binary_cross_entropy_with_logits(
            cls_pred, torch.zeros_like(cls_pred), reduction='none')
        cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
            torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
        # Average over the flattened elements.
        cls_cost = cls_cost / n

        return cls_cost

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must contain
                ``masks``.

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).

        Raises:
            NotImplementedError: If ``use_sigmoid`` is False.
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks
        if self.use_sigmoid:
            cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
        else:
            raise NotImplementedError(
                'CrossEntropyLossCost only supports use_sigmoid=True')

        return cls_cost * self.weight