Spaces: Core (Build error)
This view is limited to 50 files because it contains too many changes.
See raw diff
- __init__.py +0 -0
- app.py +102 -4
- assets/readmes/DATASET.md +42 -0
- assets/readmes/biomedparse_prediction_examples.png +0 -0
- assets/requirements/requirements_custom.txt +6 -0
- assets/scripts/eval.sh +21 -0
- assets/scripts/train.sh +41 -0
- biomedparse_working.txt +182 -0
- configs/biomed_seg_lang_v1.yaml +329 -0
- configs/biomedparse_inference.yaml +196 -0
- datasets/__init__.py +2 -0
- datasets/build.py +630 -0
- datasets/dataset_mappers/__init__.py +1 -0
- datasets/dataset_mappers/biomed_dataset_mapper.py +378 -0
- datasets/evaluation/__init__.py +8 -0
- datasets/evaluation/captioning_evaluation.py +129 -0
- datasets/evaluation/classification_evaluation.py +76 -0
- datasets/evaluation/grounding_evaluation.py +173 -0
- datasets/evaluation/instance_evaluation.py +107 -0
- datasets/evaluation/interactive_evaluation.py +122 -0
- datasets/evaluation/panoptic_evaluation.py +199 -0
- datasets/evaluation/retrieval_evaluation.py +260 -0
- datasets/evaluation/segmentation_evaluation.py +195 -0
- datasets/refer.py +371 -0
- datasets/registration/__init__.py +3 -0
- datasets/registration/register_biomed_datasets.py +123 -0
- datasets/semseg_loader.py +10 -0
- datasets/utils/refcoco2json.py +41 -0
- datasets/utils/refer.py +372 -0
- datasets/visual_sampler/__init__.py +12 -0
- datasets/visual_sampler/circle.py +106 -0
- datasets/visual_sampler/mask_generators.py +215 -0
- datasets/visual_sampler/point.py +74 -0
- datasets/visual_sampler/polygon.py +137 -0
- datasets/visual_sampler/sampler.py +77 -0
- datasets/visual_sampler/scribble.py +96 -0
- datasets/visual_sampler/simpleclick_sampler.py +252 -0
- docker/Dockerfile +32 -0
- docker/README.md +9 -0
- docker/data_env.sh +1 -0
- docker/docker_build.sh +1 -0
- docker/docker_run.sh +1 -0
- docker/setup_inside_docker.sh +10 -0
- entry.py +92 -0
- environment.yml +149 -0
- example_prediction.py +47 -0
- examples/144DME_as_F.jpeg +0 -0
- examples/C3_EndoCV2021_00462.jpg +0 -0
- examples/CT_lung_nodule.dcm +0 -0
- examples/LIDC-IDRI-0140_143_280_CT_lung.png +0 -0
__init__.py
ADDED
File without changes
app.py
CHANGED
@@ -1,7 +1,105 @@
+# Standard imports
+import torch
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
 import gradio as gr
+import os

+# Hugging Face imports
+from huggingface_hub import hf_hub_download, login
+from google.colab import userdata

+# Local imports
+from modeling.BaseModel import BaseModel
+from modeling import build_model
+from utilities.distributed import init_distributed
+from utilities.arguments import load_opt_from_config_files
+from utilities.constants import BIOMED_CLASSES
+from inference_utils.inference import interactive_infer_image
+from inference_utils.output_processing import check_mask_stats
+from inference_utils.processing_utils import read_rgb, get_instances
+
+def init_huggingface():
+    """Initialize the Hugging Face connection and download the model checkpoint."""
+    login(userdata.get('HF_TOKEN'))
+    return hf_hub_download(
+        repo_id="microsoft/BiomedParse",
+        filename="biomedparse_v1.pt",
+        local_dir="pretrained"
+    )
+
+def setup_model():
+    """Configure and return the model."""
+    opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])  # assumed config path; opt was otherwise undefined here
+    opt = init_distributed(opt)
+    model = BaseModel(opt, build_model(opt)).from_pretrained('hf_hub:microsoft/BiomedParse').eval().cuda()
+
+    with torch.no_grad():
+        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
+            BIOMED_CLASSES + ["background"],
+            is_eval=True
+        )
+    return model
+
+def process_image(image, prompts, model):
+    """Run inference on the image with the given prompts."""
+    if isinstance(image, str):
+        image = Image.open(image)
+    else:
+        image = Image.fromarray(image)
+
+    prompts = [p.strip() for p in prompts.split(',')]
+
+    pred_masks = interactive_infer_image(model, image, prompts)
+
+    fig = plt.figure(figsize=(10, 5))
+    plt.subplot(1, len(pred_masks) + 1, 1)
+    plt.imshow(image)
+    plt.title('Image originale')
+    plt.axis('off')
+
+    for i, mask in enumerate(pred_masks):
+        plt.subplot(1, len(pred_masks) + 1, i+2)
+        plt.imshow(image)
+        plt.imshow(mask, alpha=0.5, cmap='Reds')
+        plt.title(prompts[i])
+        plt.axis('off')
+
+    return fig
+
+def setup_gradio_interface(model):
+    """Configure the Gradio interface."""
+    return gr.Interface(
+        theme=gr.Theme.from_hub("allenai/gradio-theme"),
+        fn=lambda img, txt: process_image(img, txt, model),
+        inputs=[
+            gr.Image(type="numpy", label="Image médicale"),
+            gr.Textbox(
+                label="Prompts (séparés par des virgules)",
+                placeholder="edema, lesion, etc...",
+                elem_classes="white"
+            )
+        ],
+        outputs=gr.Plot(),
+        title="Core IA - Traitement d'image medicale",
+        description="Chargez une image médicale et spécifiez les éléments à segmenter",
+        examples=[
+            ["examples/144DME_as_F.jpeg", "Dans cette image donne moi l'œdème"],
+            ["examples/ISIC_0015551.jpg", "Cherche une lésion"],
+            ["examples/T0011.jpg", "disque optique, cupule optique"],
+            ["examples/C3_EndoCV2021_00462.jpg", "Trouve moi le polyp"],
+            ["examples/covid_1585.png", "Qu'est ce qui ne va pas ici ?"],
+            ['examples/Part_1_516_pathology_breast.png', "cellules néoplasiques , cellules inflammatoires , cellules du tissu conjonctif"]
+        ]
+    )
+
+def main():
+    """Main entry point of the application."""
+    init_huggingface()
+    model = setup_model()
+    interface = setup_gradio_interface(model)
+    interface.launch(debug=True)
+
+if __name__ == "__main__":
+    main()
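For reference, a rough sketch of how the functions defined in app.py could be exercised outside the Gradio UI, e.g. in a Colab notebook cell. This is not part of the commit: the `from google.colab import userdata` line means app.py as written only imports cleanly on Colab, and the example image path is simply one of the files in the examples/ folder added by this commit; everything else assumes the BiomedParse dependencies are installed and a CUDA GPU is available.

# Hedged sketch, not part of the commit: run process_image() without launching Gradio.
from app import init_huggingface, setup_model, process_image  # imports google.colab, so Colab-only as written

init_huggingface()                  # logs in with HF_TOKEN and downloads biomedparse_v1.pt into pretrained/
model = setup_model()

fig = process_image("examples/C3_EndoCV2021_00462.jpg", "polyp", model)
fig.savefig("polyp_prediction.png")  # one panel per prompt, predicted mask overlaid in red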
assets/readmes/DATASET.md
ADDED
@@ -0,0 +1,42 @@
+# **BiomedParseData**
+
+BiomedParseData was created from preprocessing publicly available biomedical image segmentation datasets.
+
+These datasets are provided pre-formatted for convenience. For additional information about the datasets or their licenses, please reach out to the owners:
+| Dataset | URL |
+|---------------------------------------|-----|
+| amos22 | [https://amos22.grand-challenge.org/](https://amos22.grand-challenge.org/) |
+| MSD (Medical Segmentation Decathlon) | [http://medicaldecathlon.com/](http://medicaldecathlon.com/) |
+| KiTS23 | [https://github.com/neheller/kits23](https://github.com/neheller/kits23) |
+| BTCV | [https://www.synapse.org/#!Synapse:syn3193805/wiki/217790](https://www.synapse.org/#!Synapse:syn3193805/wiki/217790) |
+| COVID-19 CT | [https://www.kaggle.com/datasets/andrewmvd/covid19-ct-scans](https://www.kaggle.com/datasets/andrewmvd/covid19-ct-scans) |
+| LIDC-IDRI | [https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) |
+| ACDC | [https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html](https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html) |
+| M&Ms | [https://www.ub.edu/mnms/](https://www.ub.edu/mnms/) |
+| PROMISE12 | [cite https://doi.org/10.1016/j.media.2013.12.002](https://doi.org/10.1016/j.media.2013.12.002) |
+| LGG | [https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation](https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation) |
+| COVID-QU-Ex | [https://www.kaggle.com/datasets/anasmohammedtahir/covidqu](https://www.kaggle.com/datasets/anasmohammedtahir/covidqu) |
+| QaTa-COV19 | [https://www.kaggle.com/datasets/aysendegerli/qatacov19-dataset](https://www.kaggle.com/datasets/aysendegerli/qatacov19-dataset) |
+| SIIM-ACR Pneumothorax Segmentation | [https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks](https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks) |
+| Chest Xray Masks and Labels Dataset | [https://datasetninja.com/chest-xray](https://datasetninja.com/chest-xray) |
+| COVID-19 Radiography Database | [https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database) |
+| CAMUS | [https://www.creatis.insa-lyon.fr/Challenge/camus/index.html](https://www.creatis.insa-lyon.fr/Challenge/camus/index.html) |
+| BUSI | [https://scholar.cu.edu.eg/?q=afahmy/pages/dataset](https://scholar.cu.edu.eg/?q=afahmy/pages/dataset) |
+| FH-PS-AOP | [https://zenodo.org/records/7851339#.ZEH6eHZBztU](https://zenodo.org/records/7851339#.ZEH6eHZBztU) |
+| CDD-CESM | [https://www.cancerimagingarchive.net/collection/cdd-cesm/](https://www.cancerimagingarchive.net/collection/cdd-cesm/) |
+| PolypGen | [https://www.synapse.org/#!Synapse:syn26376615/wiki/613312](https://www.synapse.org/#!Synapse:syn26376615/wiki/613312) |
+| NeoPolyp | [https://www.kaggle.com/c/bkai-igh-neopolyp/data](https://www.kaggle.com/c/bkai-igh-neopolyp/data) |
+| ISIC 2018 | [https://challenge2018.isic-archive.com/task1/](https://challenge2018.isic-archive.com/task1/) |
+| UwaterlooSkinCancer | [Skin Cancer Detection \| Vision and Image Processing Lab \| University of Waterloo](https://uwaterloo.ca) |
+| OCT-CME | [https://www.kaggle.com/datasets/zeeshanahmed13/intraretinal-cystoid-fluid](https://www.kaggle.com/datasets/zeeshanahmed13/intraretinal-cystoid-fluid) |
+| REFUGE | [https://bitbucket.org/woalsdnd/refuge/src](https://bitbucket.org/woalsdnd/refuge/src) |
+| G1020 | [https://www.dfki.uni-kl.de/g1020](https://www.dfki.uni-kl.de/g1020) |
+| DRIVE | [https://drive.grand-challenge.org/](https://drive.grand-challenge.org/) |
+| GlaS | [https://warwick.ac.uk/fac/cross_fac/tia/data/glascontest/](https://warwick.ac.uk/fac/cross_fac/tia/data/glascontest/) |
+| PanNuke | [https://jgamper.github.io/PanNukeDataset/](https://jgamper.github.io/PanNukeDataset/) |
+| FUMPE | [https://figshare.com/collections/FUMPE/4107803/1](https://figshare.com/collections/FUMPE/4107803/1) |
+| TotalSegmentator | [https://github.com/wasserth/TotalSegmentator](https://github.com/wasserth/TotalSegmentator) |
+| BraTS2023 | [https://www.synapse.org/#!Synapse:syn51156910/wiki/621282](https://www.synapse.org/#!Synapse:syn51156910/wiki/621282) |
+| AbdomenCT-1K | [https://github.com/JunMa11/AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) |
+| US Simulation & Segmentation | [https://www.kaggle.com/datasets/ignaciorlando/ussimandsegm](https://www.kaggle.com/datasets/ignaciorlando/ussimandsegm) |
+| CDD-CESM | [https://www.cancerimagingarchive.net/collection/cdd-cesm/](https://www.cancerimagingarchive.net/collection/cdd-cesm/) |
assets/readmes/biomedparse_prediction_examples.png
ADDED
assets/requirements/requirements_custom.txt
ADDED
@@ -0,0 +1,6 @@
+git+https://github.com/cocodataset/panopticapi.git
+git+https://github.com/openai/CLIP.git
+#git+https://github.com/arogozhnikov/einops.git
+#git+https://github.com/facebookresearch/detectron2.git
+git+https://github.com/MaureenZOU/detectron2-xyz.git
+#git+https://github.com/openai/whisper.git
assets/scripts/eval.sh
ADDED
@@ -0,0 +1,21 @@
+export DETECTRON2_DATASETS=biomedparse_datasets/
+export DATASET=biomedparse_datasets/
+export DATASET2=biomedparse_datasets/
+export VLDATASET=biomedparse_datasets/
+export PATH=$PATH:biomedparse_datasets/coco_caption/jre1.8.0_321/bin/
+export PYTHONPATH=$PYTHONPATH:biomedparse_datasets/coco_caption/
+export OMPI_ALLOW_RUN_AS_ROOT=1
+export OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
+#export WANDB_KEY=YOUR_WANDB_KEY # Provide your wandb key here
+CUDA_VISIBLE_DEVICES=0 mpirun -n 1 python entry.py evaluate \
+            --conf_files configs/biomed_seg_lang_v1.yaml \
+            --overrides \
+            MODEL.DECODER.HIDDEN_DIM 512 \
+            MODEL.ENCODER.CONVS_DIM 512 \
+            MODEL.ENCODER.MASK_DIM 512 \
+            TEST.BATCH_SIZE_TOTAL 1 \
+            FP16 True \
+            WEIGHT True \
+            STANDARD_TEXT_FOR_EVAL False \
+            RESUME_FROM pretrained/biomedparse_v1.pt \
+
assets/scripts/train.sh
ADDED
@@ -0,0 +1,41 @@
+export DETECTRON2_DATASETS=biomedparse_datasets/
+export DATASET=biomedparse_datasets/
+export DATASET2=biomedparse_datasets/
+export VLDATASET=biomedparse_datasets/
+export PATH=$PATH:biomedparse_datasets/coco_caption/jre1.8.0_321/bin/
+export PYTHONPATH=$PYTHONPATH:biomedparse_datasets/coco_caption/
+export OMPI_ALLOW_RUN_AS_ROOT=1
+export OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
+#export WANDB_KEY=YOUR_WANDB_KEY # Provide your wandb key here
+CUDA_VISIBLE_DEVICES=0 mpirun -n 1 python entry.py train \
+            --conf_files configs/biomed_seg_lang_v1.yaml \
+            --overrides \
+            FP16 True \
+            RANDOM_SEED 2024 \
+            BioMed.INPUT.IMAGE_SIZE 1024 \
+            MODEL.DECODER.HIDDEN_DIM 512 \
+            MODEL.ENCODER.CONVS_DIM 512 \
+            MODEL.ENCODER.MASK_DIM 512 \
+            TEST.BATCH_SIZE_TOTAL 4 \
+            TRAIN.BATCH_SIZE_TOTAL 4 \
+            TRAIN.BATCH_SIZE_PER_GPU 4 \
+            SOLVER.MAX_NUM_EPOCHS 20 \
+            SOLVER.BASE_LR 0.00001 \
+            SOLVER.FIX_PARAM.backbone False \
+            SOLVER.FIX_PARAM.lang_encoder False \
+            SOLVER.FIX_PARAM.pixel_decoder False \
+            MODEL.DECODER.COST_SPATIAL.CLASS_WEIGHT 1.0 \
+            MODEL.DECODER.COST_SPATIAL.MASK_WEIGHT 1.0 \
+            MODEL.DECODER.COST_SPATIAL.DICE_WEIGHT 1.0 \
+            MODEL.DECODER.TOP_SPATIAL_LAYERS 10 \
+            MODEL.DECODER.SPATIAL.ENABLED True \
+            MODEL.DECODER.GROUNDING.ENABLED True \
+            LOADER.SAMPLE_PROB prop \
+            BioMed.INPUT.RANDOM_ROTATE True \
+            FIND_UNUSED_PARAMETERS True \
+            ATTENTION_ARCH.SPATIAL_MEMORIES 32 \
+            MODEL.DECODER.SPATIAL.MAX_ITER 0 \
+            ATTENTION_ARCH.QUERY_NUMBER 3 \
+            STROKE_SAMPLER.MAX_CANDIDATE 10 \
+            WEIGHT True \
+            RESUME_FROM pretrained/biomedparse_v1.pt
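The `--overrides` arguments in eval.sh and train.sh above are flat `KEY.PATH value` pairs applied on top of the YAML config passed via `--conf_files`. As a rough illustration only (the actual parsing lives in entry.py and the utilities package and may differ in details such as type handling), here is the kind of mapping such pairs perform on the nested config:

# Hedged illustration: how dotted override pairs could be applied to the nested YAML config.
import yaml

def apply_override(cfg, dotted_key, value):
    """Walk the dotted path and set the leaf value."""
    keys = dotted_key.split(".")
    node = cfg
    for k in keys[:-1]:
        node = node.setdefault(k, {})
    node[keys[-1]] = yaml.safe_load(value)  # turns "0.00001"/"True"/"prop" into float/bool/str

with open("configs/biomed_seg_lang_v1.yaml") as f:
    cfg = yaml.safe_load(f)

apply_override(cfg, "SOLVER.BASE_LR", "0.00001")
apply_override(cfg, "TRAIN.BATCH_SIZE_PER_GPU", "4")
print(cfg["SOLVER"]["BASE_LR"], cfg["TRAIN"]["BATCH_SIZE_PER_GPU"])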
biomedparse_working.txt
ADDED
@@ -0,0 +1,182 @@
+name: biomedparse
+channels:
+  - pytorch
+  - nvidia
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - blas=1.0=mkl
+  - brotli-python=1.0.9=py312h6a678d5_8
+  - bzip2=1.0.8=h5eee18b_6
+  - ca-certificates=2024.7.2=h06a4308_0
+  - certifi=2024.7.4=py312h06a4308_0
+  - charset-normalizer=3.3.2=pyhd3eb1b0_0
+  - cuda-cudart=12.4.127=0
+  - cuda-cupti=12.4.127=0
+  - cuda-libraries=12.4.0=0
+  - cuda-nvrtc=12.4.127=0
+  - cuda-nvtx=12.4.127=0
+  - cuda-opencl=12.6.68=0
+  - cuda-runtime=12.4.0=0
+  - cuda-version=12.6=3
+  - expat=2.6.2=h6a678d5_0
+  - ffmpeg=4.3=hf484d3e_0
+  - filelock=3.13.1=py312h06a4308_0
+  - freetype=2.12.1=h4a9f257_0
+  - gmp=6.2.1=h295c915_3
+  - gnutls=3.6.15=he1e5248_0
+  - idna=3.7=py312h06a4308_0
+  - intel-openmp=2023.1.0=hdb19cb5_46306
+  - jinja2=3.1.4=py312h06a4308_0
+  - jpeg=9e=h5eee18b_3
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libcublas=12.4.2.65=0
+  - libcufft=11.2.0.44=0
+  - libcufile=1.11.1.6=0
+  - libcurand=10.3.7.68=0
+  - libcusolver=11.6.0.99=0
+  - libcusparse=12.3.0.142=0
+  - libdeflate=1.17=h5eee18b_1
+  - libffi=3.4.4=h6a678d5_1
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libiconv=1.16=h5eee18b_3
+  - libidn2=2.3.4=h5eee18b_0
+  - libjpeg-turbo=2.0.0=h9bf148f_0
+  - libnpp=12.2.5.2=0
+  - libnvfatbin=12.6.68=0
+  - libnvjitlink=12.4.99=0
+  - libnvjpeg=12.3.1.89=0
+  - libpng=1.6.39=h5eee18b_0
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.1=h6a678d5_0
+  - libunistring=0.9.10=h27cfd23_0
+  - libuuid=1.41.5=h5eee18b_0
+  - libwebp-base=1.3.2=h5eee18b_0
+  - llvm-openmp=14.0.6=h9e868ea_0
+  - lz4-c=1.9.4=h6a678d5_1
+  - markupsafe=2.1.3=py312h5eee18b_0
+  - mkl=2023.1.0=h213fc3f_46344
+  - mkl-service=2.4.0=py312h5eee18b_1
+  - mkl_fft=1.3.10=py312h5eee18b_0
+  - mkl_random=1.2.7=py312h526ad5a_0
+  - mpmath=1.3.0=py312h06a4308_0
+  - ncurses=6.4=h6a678d5_0
+  - nettle=3.7.3=hbbd107a_1
+  - networkx=3.2.1=py312h06a4308_0
+  - openh264=2.1.1=h4ff587b_0
+  - openjpeg=2.5.2=he7f1fd0_0
+  - openssl=3.0.14=h5eee18b_0
+  - pip=24.2=py312h06a4308_0
+  - pysocks=1.7.1=py312h06a4308_0
+  - python=3.12.4=h5148396_1
+  - pytorch=2.4.1=py3.12_cuda12.4_cudnn9.1.0_0
+  - pytorch-cuda=12.4=hc786d27_6
+  - pytorch-mutex=1.0=cuda
+  - pyyaml=6.0.1=py312h5eee18b_0
+  - readline=8.2=h5eee18b_0
+  - requests=2.32.3=py312h06a4308_0
+  - setuptools=72.1.0=py312h06a4308_0
+  - sqlite=3.45.3=h5eee18b_0
+  - sympy=1.13.2=py312h06a4308_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.14=h39e8969_0
+  - torchaudio=2.4.1=py312_cu124
+  - torchtriton=3.0.0=py312
+  - torchvision=0.19.1=py312_cu124
+  - typing_extensions=4.11.0=py312h06a4308_0
+  - urllib3=2.2.2=py312h06a4308_0
+  - wheel=0.43.0=py312h06a4308_0
+  - xz=5.4.6=h5eee18b_1
+  - yaml=0.2.5=h7b6447c_0
+  - zlib=1.2.13=h5eee18b_1
+  - zstd=1.5.5=hc292b87_2
+  - pip:
+    - absl-py==2.1.0
+    - accelerate==0.23.0
+    - antlr4-python3-runtime==4.9.3
+    - appdirs==1.4.4
+    - black==21.4b2
+    - click==8.1.7
+    - clip==1.0
+    - cloudpickle==3.0.0
+    - contourpy==1.3.0
+    - cycler==0.12.1
+    - cython==3.0.2
+    - deepspeed==0.10.3
+    - detectron2==0.6
+    - diffdist==0.1
+    - einops==0.7.0
+    - fonttools==4.53.1
+    - fsspec==2024.9.0
+    - ftfy==6.1.1
+    - future==1.0.0
+    - fvcore==0.1.5.post20221221
+    - grpcio==1.66.1
+    - hjson==3.1.0
+    - huggingface-hub==0.17.3
+    - hydra-core==1.3.2
+    - imageio==2.35.1
+    - infinibatch==0.1.1
+    - iopath==0.1.9
+    - joblib==1.4.2
+    - json-tricks==3.17.3
+    - kiwisolver==1.4.7
+    - kornia==0.7.0
+    - lazy-loader==0.4
+    - markdown==3.7
+    - matplotlib==3.9.2
+    - mup==1.0.0
+    - mypy-extensions==1.0.0
+    - ninja==1.11.1.1
+    - nltk==3.8.1
+    - numpy==1.26.4
+    - omegaconf==2.3.0
+    - opencv-python==4.8.1.78
+    - packaging==24.1
+    - pandas==2.0.3
+    - panopticapi==0.1
+    - pathspec==0.12.1
+    - pillow==9.4.0
+    - portalocker==2.10.1
+    - protobuf==5.28.0
+    - psutil==6.0.0
+    - py-cpuinfo==9.0.0
+    - pycocotools==2.0.7
+    - pydantic==1.10.18
+    - pydot==3.0.1
+    - pyparsing==3.1.4
+    - python-dateutil==2.9.0.post0
+    - pytz==2024.1
+    - pywavelets==1.7.0
+    - regex==2023.10.3
+    - safetensors==0.4.4
+    - scikit-image==0.21.0
+    - scikit-learn==1.3.1
+    - scipy==1.14.1
+    - seaborn==0.13.2
+    - sentencepiece==0.1.99
+    - six==1.16.0
+    - tabulate==0.9.0
+    - tenacity==9.0.0
+    - tensorboard==2.17.1
+    - tensorboard-data-server==0.7.2
+    - termcolor==2.4.0
+    - threadpoolctl==3.5.0
+    - tifffile==2024.8.30
+    - timm==0.4.12
+    - tokenizers==0.14.1
+    - toml==0.10.2
+    - tqdm==4.66.5
+    - transformers==4.34.0
+    - tzdata==2024.1
+    - vision-datasets==0.2.2
+    - wcwidth==0.2.13
+    - werkzeug==3.0.4
+    - yacs==0.1.8
+prefix: /anaconda/envs/biomedparse
configs/biomed_seg_lang_v1.yaml
ADDED
@@ -0,0 +1,329 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou ([email protected])
+# --------------------------------------------------------
+
+# Define Test/Trainer/Saving
+PIPELINE: XDecoderPipeline
+TRAINER: xdecoder
+SAVE_DIR: './output'
+base_path: "./"
+
+# Resume Logistic
+RESUME: false
+WEIGHT: false
+RESUME_FROM: ''
+EVAL_AT_START: false
+SAVE_CHECKPOINT: True
+
+# Logging and Debug
+WANDB: False
+LOG_EVERY: 100
+FIND_UNUSED_PARAMETERS: false
+
+# Speed up training
+FP16: false
+PORT: '36873'
+
+# misc
+LOADER:
+  JOINT: True
+  KEY_DATASET: ""
+  SAMPLE_PROB: "prop" # sampling probability proportional to data size. Use "equal" for each bach from all datasets
+  MIXING_LEVEL: 1 # num of different datasets for batch mixing on each GPU
+
+RANDOM_SEED: 2024
+
+STANDARD_TEXT_FOR_EVAL: False
+
+##################
+# Task settings
+##################
+VERBOSE: true
+MODEL:
+  NAME: seem_model_v1
+  HEAD: xdecoder_head
+  MASK_ON: false
+  KEYPOINT_ON: false
+  LOAD_PROPOSALS: false
+  DIM_PROJ: 512
+  TEXT:
+    ARCH: vlpencoder
+    NAME: transformer
+    TOKENIZER: clip
+    CONTEXT_LENGTH: 77 #256 # 77
+    WIDTH: 512 # 768 # 512
+    HEADS: 8
+    LAYERS: 12 # 6
+    AUTOGRESSIVE: True
+  BACKBONE:
+    NAME: focal # focal_dw # focal
+    PRETRAINED: ''
+    LOAD_PRETRAINED: false
+    FOCAL:
+      PRETRAIN_IMG_SIZE: 224
+      PATCH_SIZE: 4
+      EMBED_DIM: 192 # 96 # 192
+      DEPTHS: [2, 2, 18, 2] # [2, 2, 6, 2] # [2, 2, 18, 2]
+      FOCAL_LEVELS: [4, 4, 4, 4] # [3, 3, 3, 3] # [4, 4, 4, 4]
+      FOCAL_WINDOWS: [3, 3, 3, 3]
+      DROP_PATH_RATE: 0.3
+      MLP_RATIO: 4.0
+      DROP_RATE: 0.0
+      PATCH_NORM: True
+      USE_CONV_EMBED: True
+      SCALING_MODULATOR: True
+      USE_CHECKPOINT: False
+      USE_POSTLN: true
+      USE_POSTLN_IN_MODULATION: false
+      USE_LAYERSCALE: True
+      OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+      OUT_INDICES: [0, 1, 2, 3]
+  ENCODER:
+    NAME: transformer_encoder_fpn
+    IGNORE_VALUE: 255
+    NUM_CLASSES: 16
+    BINARY_CLASSES: False
+    LOSS_WEIGHT: 1.0
+    CONVS_DIM: 512
+    MASK_DIM: 512
+    NORM: "GN"
+    IN_FEATURES: ["res2", "res3", "res4", "res5"]
+    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+    COMMON_STRIDE: 4
+    TRANSFORMER_ENC_LAYERS: 6
+  DECODER:
+    NAME: seem_v1
+    TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+    MASK:
+      ENABLED: True
+    DETECTION: False
+    SPATIAL:
+      ENABLED: True
+      MAX_ITER: 1
+    GROUNDING:
+      ENABLED: True
+      MAX_LEN: 10
+      TEXT_WEIGHT: 2.0
+      CLASS_WEIGHT: 0.5
+    RETRIEVAL:
+      ENABLED: False
+    LVIS:
+      ENABLED: False
+      THRES: 0.7
+    OPENIMAGE:
+      ENABLED: False
+      NEGATIVE_SAMPLES: 5
+      GROUNDING:
+        ENABLED: False
+        MAX_LEN: 5
+    CAPTION:
+      ENABLED: False
+      PHRASE_PROB: 0.5
+      SIM_THRES: 0.95
+    DEEP_SUPERVISION: True
+    NO_OBJECT_WEIGHT: 0.1
+    GCLASS_WEIGHT: 0.4
+    GMASK_WEIGHT: 1.0
+    GDICE_WEIGHT: 1.0
+    SCLASS_WEIGHT: 0.4
+    SMASK_WEIGHT: 1.0
+    SDICE_WEIGHT: 1.0
+    OCLASS_WEIGHT: 0.4
+    OMASK_WEIGHT: 1.0
+    ODICE_WEIGHT: 1.0
+    CLASS_WEIGHT: 2.0
+    MASK_WEIGHT: 5.0
+    DICE_WEIGHT: 5.0
+    BBOX_WEIGHT: 5.0
+    GIOU_WEIGHT: 2.0
+    CAPTION_WEIGHT: 2.0
+    COST_SPATIAL:
+      CLASS_WEIGHT: 5.0
+      MASK_WEIGHT: 2.0
+      DICE_WEIGHT: 2.0
+    HIDDEN_DIM: 512
+    NUM_OBJECT_QUERIES: 101
+    NHEADS: 8
+    DROPOUT: 0.0
+    DIM_FEEDFORWARD: 2048
+    MAX_SPATIAL_LEN: [512, 512, 512, 512]
+    # ENC_LAYERS: 0
+    PRE_NORM: False
+    ENFORCE_INPUT_PROJ: False
+    SIZE_DIVISIBILITY: 32
+    TRAIN_NUM_POINTS: 12544
+    OVERSAMPLE_RATIO: 3.0
+    IMPORTANCE_SAMPLE_RATIO: 0.75
+    DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+    TOP_GROUNDING_LAYERS: 10
+    TOP_CAPTION_LAYERS: 10
+    TOP_SPATIAL_LAYERS: 10
+    TOP_OPENIMAGE_LAYERS: 10
+    TEST:
+      SEMANTIC_ON: False
+      INSTANCE_ON: False
+      PANOPTIC_ON: False
+      OVERLAP_THRESHOLD: 0.8
+      OBJECT_MASK_THRESHOLD: 0.8
+      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: true
+
+# Spatial sampler
+STROKE_SAMPLER:
+  MAX_CANDIDATE: 1
+  CANDIDATE_PROBS: [0.25, 0.25, 0.25, 0.25] # for training only
+  CANDIDATE_NAMES: ["Point", "Polygon", "Scribble", "Circle"]
+  DILATION: 3
+  CIRCLE:
+    NUM_STROKES: 5
+    STROKE_PRESET: ['object_like', 'object_like_middle', 'object_like_small']
+    STROKE_PROB: [0.33, 0.33, 0.33]
+  SCRIBBLE:
+    NUM_STROKES: 5
+    STROKE_PRESET: ['rand_curve', 'rand_curve_small']
+    STROKE_PROB: [0.5, 0.5]
+  POINT:
+    NUM_POINTS: 20
+  POLYGON:
+    MAX_POINTS: 9
+  EVAL:
+    MODE: 'best' # best/random/best_random
+    NEGATIVE: False
+    MAX_ITER: 1
+    IOU_ITER: 1
+    GROUNDING: True
+
+# Multi-modal Architecture, order matters
+ATTENTION_ARCH:
+  VARIABLE:
+    queries: ['object', 'grounding', 'spatial']
+    tokens: ['grounding', 'spatial']
+    memories: ['spatial']
+  SELF_ATTENTION:
+    queries:
+      object: ['queries_object']
+      grounding: ['queries_grounding', 'tokens_grounding']
+      spatial: ['queries_spatial', 'tokens_spatial', 'memories_spatial']
+    tokens:
+      grounding: ['queries_grounding', 'tokens_grounding']
+      spatial: ['tokens_spatial']
+    memories:
+      spatial: ['memories_spatial']
+  CROSS_ATTENTION:
+    queries:
+      object: True
+      grounding: True
+      spatial: True
+    memories:
+      spatial: True
+    tokens:
+      grounding: False
+      spatial: False
+  MASKING: ['tokens_spatial', 'tokens_grounding']
+  DUPLICATION:
+    queries:
+      grounding: 'queries_object'
+      spatial: 'queries_object'
+  SPATIAL_MEMORIES: 32
+  QUERY_NUMBER: 3
+
+DATASETS:
+  TRAIN: [
+    'biomed_BiomedParseData-Demo_demo' # Add your registered training datasets here
+  ]
+
+
+  TEST: [
+    'biomed_BiomedParseData-Demo_demo' # Add your registered test datasets here
+  ]
+
+  CLASS_CONCAT: false
+  SIZE_DIVISIBILITY: 32
+  PROPOSAL_FILES_TRAIN: []
+
+INPUT:
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+
+TRAIN:
+  ASPECT_RATIO_GROUPING: true
+  BATCH_SIZE_TOTAL: 4
+  BATCH_SIZE_PER_GPU: 4
+  SHUFFLE: true
+
+TEST:
+  DETECTIONS_PER_IMAGE: 100
+  NAME: coco_eval
+  IOU_TYPE: ['bbox', 'segm']
+  USE_MULTISCALE: false
+  BATCH_SIZE_TOTAL: 4
+  MODEL_FILE: ''
+  AUG:
+    ENABLED: False
+
+DATALOADER:
+  FILTER_EMPTY_ANNOTATIONS: False
+  NUM_WORKERS: 8
+  LOAD_PROPOSALS: False
+  SAMPLER_TRAIN: "TrainingSampler"
+  ASPECT_RATIO_GROUPING: True
+
+
+BioMed:
+  INPUT:
+    PIXEL_MEAN: [64.284, 59.293, 59.962]
+    PIXEL_STD: [62.484, 60.865, 59.835]
+    DATASET_MAPPER_NAME: "biomed_interactive"
+    MIN_SIZE_TRAIN: 900
+    MAX_SIZE_TRAIN: 1100
+    MIN_SIZE_TRAIN_SAMPLING: 'choice'
+    MIN_SIZE_TEST: 900
+    MAX_SIZE_TEST: 1100
+    IMAGE_SIZE: 1024
+    MIN_SCALE: 0.9
+    MAX_SCALE: 1.1
+    IGNORE_VALUE: 255
+    COLOR_AUG_SSD: False
+    SIZE_DIVISIBILITY: 32
+    RANDOM_FLIP: "none"
+    RANDOM_ROTATE: False
+    MASK_FORMAT: "polygon"
+    MIN_AREA: 30
+    FORMAT: "RGB"
+    SPATIAL: True
+    CROP:
+      ENABLED: True
+  DATASET:
+    DATASET: "biomed"
+
+
+# Detectron2 training config for optimizer and lr scheduler
+SOLVER:
+  BASE_LR: 0.0001
+  STEPS: [0.88889, 0.96296]
+  MAX_ITER: 1
+  GAMMA: 0.1
+  WARMUP_FACTOR: 1.0
+  WARMUP_ITERS: 10
+  WARMUP_METHOD: "linear"
+  WEIGHT_DECAY: 0.05
+  OPTIMIZER: "ADAMW"
+  LR_SCHEDULER_NAME: "WarmupMultiStepLR"
+  LR_MULTIPLIER:
+    backbone: 0.1
+    lang_encoder: 0.1
+  FIX_PARAM:
+    backbone: True
+    lang_encoder: True
+    pixel_decoder: True
+  WEIGHT_DECAY_NORM: 0.0
+  WEIGHT_DECAY_EMBED: 0.0
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 5.0 # 0.01
+    NORM_TYPE: 2.0
+  MAX_NUM_EPOCHS: 50
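The names under DATASETS.TRAIN and DATASETS.TEST in this config must match datasets registered with detectron2's DatasetCatalog, which datasets/registration/register_biomed_datasets.py appears to handle when the datasets package is imported (datasets/__init__.py pulls in the registration module). A small hedged sanity check, assuming the biomedparse_datasets/ layout expected by the registration code is in place:

# Hedged sketch: confirm the dataset name used in DATASETS.TRAIN/TEST is registered.
from detectron2.data import DatasetCatalog
import datasets  # importing the repo's datasets package runs its registration module

name = "biomed_BiomedParseData-Demo_demo"
print(name, name in DatasetCatalog.list())  # expect True once the demo data folder exists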
configs/biomedparse_inference.yaml
ADDED
@@ -0,0 +1,196 @@
+# Define Test/Trainer/Saving
+PIPELINE: XDecoderPipeline
+TRAINER: xdecoder
+SAVE_DIR: '../../data/output/test'
+base_path: "./"
+
+# Resume Logistic
+RESUME: false
+WEIGHT: false
+RESUME_FROM: ''
+EVAL_AT_START: false
+
+# Logging and Debug
+WANDB: False
+LOG_EVERY: 100
+FIND_UNUSED_PARAMETERS: false
+
+# Speed up training
+FP16: false
+PORT: '36873'
+
+# misc
+LOADER:
+  JOINT: False
+  KEY_DATASET: 'coco'
+
+STANDARD_TEXT_FOR_EVAL: False
+
+##################
+# Task settings
+##################
+VERBOSE: true
+MODEL:
+  NAME: seem_model_demo
+  HEAD: xdecoder_head
+  DIM_PROJ: 512
+  TEXT:
+    ARCH: vlpencoder
+    NAME: transformer
+    TOKENIZER: clip
+    CONTEXT_LENGTH: 77 # 77
+    WIDTH: 512
+    HEADS: 8
+    LAYERS: 12 # 6
+    AUTOGRESSIVE: True
+  BACKBONE:
+    NAME: focal
+    PRETRAINED: ''
+    LOAD_PRETRAINED: false
+    FOCAL:
+      PRETRAIN_IMG_SIZE: 224
+      PATCH_SIZE: 4
+      EMBED_DIM: 192
+      DEPTHS: [2, 2, 18, 2]
+      FOCAL_LEVELS: [4, 4, 4, 4]
+      FOCAL_WINDOWS: [3, 3, 3, 3]
+      DROP_PATH_RATE: 0.3
+      MLP_RATIO: 4.0
+      DROP_RATE: 0.0
+      PATCH_NORM: True
+      USE_CONV_EMBED: True
+      SCALING_MODULATOR: True
+      USE_CHECKPOINT: False
+      USE_POSTLN: true
+      USE_POSTLN_IN_MODULATION: false
+      USE_LAYERSCALE: True
+      OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+      OUT_INDICES: [0, 1, 2, 3]
+  ENCODER:
+    NAME: transformer_encoder_fpn
+    IGNORE_VALUE: 255
+    NUM_CLASSES: 16
+    BINARY_CLASSES: False
+    LOSS_WEIGHT: 1.0
+    CONVS_DIM: 512
+    MASK_DIM: 512
+    NORM: "GN"
+    IN_FEATURES: ["res2", "res3", "res4", "res5"]
+    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+    COMMON_STRIDE: 4
+    TRANSFORMER_ENC_LAYERS: 6
+  DECODER:
+    NAME: seem_demo
+    TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+    MASK:
+      ENABLED: False
+    DETECTION: False
+    SPATIAL:
+      ENABLED: True
+      MAX_ITER: 1
+    GROUNDING:
+      ENABLED: True
+      MAX_LEN: 5
+      TEXT_WEIGHT: 2.0
+      CLASS_WEIGHT: 0.5
+    VISUAL:
+      ENABLED: False
+    AUDIO:
+      ENABLED: False
+    RETRIEVAL:
+      ENABLED: False
+    LVIS:
+      ENABLED: True
+      THRES: 0.7
+    OPENIMAGE:
+      ENABLED: False
+      NEGATIVE_SAMPLES: 5
+      GROUNDING:
+        ENABLED: False
+        MAX_LEN: 5
+    CAPTION:
+      ENABLED: False
+      PHRASE_PROB: 0.5
+      SIM_THRES: 0.95
+    DEEP_SUPERVISION: True
+    NO_OBJECT_WEIGHT: 0.1
+    GCLASS_WEIGHT: 0.4
+    GMASK_WEIGHT: 1.0
+    GDICE_WEIGHT: 1.0
+    SCLASS_WEIGHT: 0.4
+    SMASK_WEIGHT: 1.0
+    SDICE_WEIGHT: 1.0
+    OCLASS_WEIGHT: 0.4
+    OMASK_WEIGHT: 1.0
+    ODICE_WEIGHT: 1.0
+    CLASS_WEIGHT: 2.0
+    MASK_WEIGHT: 5.0
+    DICE_WEIGHT: 5.0
+    BBOX_WEIGHT: 5.0
+    GIOU_WEIGHT: 2.0
+    CAPTION_WEIGHT: 2.0
+    COST_SPATIAL:
+      CLASS_WEIGHT: 5.0
+      MASK_WEIGHT: 2.0
+      DICE_WEIGHT: 2.0
+    HIDDEN_DIM: 512
+    NUM_OBJECT_QUERIES: 101
+    NHEADS: 8
+    DROPOUT: 0.0
+    DIM_FEEDFORWARD: 2048
+    MAX_SPATIAL_LEN: [512, 512, 512, 512]
+    # ENC_LAYERS: 0
+    PRE_NORM: False
+    ENFORCE_INPUT_PROJ: False
+    SIZE_DIVISIBILITY: 32
+    TRAIN_NUM_POINTS: 12544
+    OVERSAMPLE_RATIO: 3.0
+    IMPORTANCE_SAMPLE_RATIO: 0.75
+    DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+    TOP_GROUNDING_LAYERS: 10
+    TOP_CAPTION_LAYERS: 10
+    TOP_SPATIAL_LAYERS: 10
+    TOP_OPENIMAGE_LAYERS: 10
+    TEST:
+      SEMANTIC_ON: True
+      INSTANCE_ON: True
+      PANOPTIC_ON: True
+      OVERLAP_THRESHOLD: 0.8
+      OBJECT_MASK_THRESHOLD: 0.4
+      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
+      DETECTIONS_PER_IMAGE: 100
+
+# Multi-modal Architecture, order matters
+ATTENTION_ARCH:
+  VARIABLE:
+    queries: ['object']
+    tokens: ['grounding', 'spatial', 'visual', 'audio']
+  SELF_ATTENTION:
+    queries:
+      object: ['queries_object', 'tokens_grounding', 'tokens_spatial', 'tokens_visual', 'tokens_audio']
+    tokens:
+      grounding: ['queries_object', 'tokens_grounding']
+      spatial: ['tokens_spatial']
+      visual: ['tokens_visual']
+      audio: ['queries_object', 'tokens_audio']
+  CROSS_ATTENTION:
+    queries:
+      object: True
+    tokens:
+      grounding: False
+      spatial: False
+      visual: False
+      audio: False
+  MASKING: ['tokens_spatial', 'tokens_grounding', 'tokens_visual', 'tokens_audio']
+  DUPLICATION:
+    queries:
+      grounding: 'queries_object'
+      spatial: 'queries_object'
+  SPATIAL_MEMORIES: 32
+
+INPUT:
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+# INPUT:
+#   PIXEL_MEAN: [64.284, 59.293, 59.962]
+#   PIXEL_STD: [62.484, 60.865, 59.835]
datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from . import registration
+from .build import build_train_dataloader, build_eval_dataloader, build_evaluator
datasets/build.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
| 3 |
+
# Copyright (c) 2022 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Modified by Xueyan Zou ([email protected])
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
import itertools
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.utils.data
|
| 17 |
+
import torch.utils.data as torchdata
|
| 18 |
+
|
| 19 |
+
import detectron2.utils.comm as comm
|
| 20 |
+
from detectron2.data.build import (
|
| 21 |
+
build_batch_data_loader,
|
| 22 |
+
load_proposals_into_dataset,
|
| 23 |
+
trivial_batch_collator,
|
| 24 |
+
)
|
| 25 |
+
from detectron2.data import MetadataCatalog
|
| 26 |
+
from detectron2.data.catalog import DatasetCatalog
|
| 27 |
+
from detectron2.data.common import DatasetFromList, MapDataset
|
| 28 |
+
from detectron2.data.dataset_mapper import DatasetMapper
|
| 29 |
+
from detectron2.data.samplers import InferenceSampler, TrainingSampler
|
| 30 |
+
from detectron2.evaluation import (
|
| 31 |
+
CityscapesInstanceEvaluator,
|
| 32 |
+
CityscapesSemSegEvaluator,
|
| 33 |
+
COCOEvaluator,
|
| 34 |
+
DatasetEvaluators,
|
| 35 |
+
LVISEvaluator,
|
| 36 |
+
verify_results,
|
| 37 |
+
)
|
| 38 |
+
from fvcore.common.config import CfgNode
|
| 39 |
+
|
| 40 |
+
from .dataset_mappers import *
|
| 41 |
+
from .evaluation import (InstanceSegEvaluator,
|
| 42 |
+
ClassificationEvaluator,
|
| 43 |
+
SemSegEvaluator,
|
| 44 |
+
RetrievalEvaluator,
|
| 45 |
+
#CaptioningEvaluator,
|
| 46 |
+
COCOPanopticEvaluator,
|
| 47 |
+
GroundingEvaluator,
|
| 48 |
+
InteractiveEvaluator,
|
| 49 |
+
)
|
| 50 |
+
from modeling.utils import configurable
|
| 51 |
+
from utilities.distributed import get_world_size
|
| 52 |
+
|
| 53 |
+
class JointLoader(torchdata.IterableDataset):
|
| 54 |
+
"""
|
| 55 |
+
Randomly sampple from one of the dataloaders per worker in each iteration.
|
| 56 |
+
The sampling probability is determined by the size of each dataset.
|
| 57 |
+
All examples from one worker (GPU) are from the same dataset in the iteration.
|
| 58 |
+
Mixing is achieved through multiple workers (GPUs).
|
| 59 |
+
"""
|
| 60 |
+
def __init__(self, loaders, key_dataset, sample_prob, mixing_level):
|
| 61 |
+
dataset_names = []
|
| 62 |
+
for key, loader in loaders.items():
|
| 63 |
+
name = "{}".format(key.split('_')[0])
|
| 64 |
+
setattr(self, name, loader)
|
| 65 |
+
dataset_names += [name]
|
| 66 |
+
self.dataset_names = dataset_names
|
| 67 |
+
self.key_dataset = key_dataset
|
| 68 |
+
if sample_prob == 'prop':
|
| 69 |
+
self.sample_prob = [len(getattr(self, key)) for key in self.dataset_names]
|
| 70 |
+
elif sample_prob == 'equal':
|
| 71 |
+
self.sample_prob = [1 for key in self.dataset_names]
|
| 72 |
+
elif sample_prob == 'sqrt':
|
| 73 |
+
self.sample_prob = [np.sqrt(len(getattr(self, key))) for key in self.dataset_names]
|
| 74 |
+
self.sample_prob = [p/sum(self.sample_prob) for p in self.sample_prob]
|
| 75 |
+
self.mixing_level = mixing_level
|
| 76 |
+
|
| 77 |
+
# Not sure how expensive `len(getattr(self, name))` is. computing this once and cache.
|
| 78 |
+
# this assumes the len of the underlying data loaders do not change.
|
| 79 |
+
self._len = sum(len(getattr(self, name)) for name in self.dataset_names)
|
| 80 |
+
|
| 81 |
+
def __iter__(self):
|
| 82 |
+
# Reset iterators at the start of each new epoch
|
| 83 |
+
self.iterators = {name: iter(getattr(self, name)) for name in self.dataset_names}
|
| 84 |
+
self._count = 0
|
| 85 |
+
return self
|
| 86 |
+
|
| 87 |
+
def __next__(self):
|
| 88 |
+
while self._count < self._len:
|
| 89 |
+
# Randomly select a dataloader
|
| 90 |
+
name = np.random.choice(self.dataset_names, size=None, replace=False, p=self.sample_prob)
|
| 91 |
+
iterator = self.iterators[name]
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# Get next batch from the selected dataloader
|
| 95 |
+
self._count += 1
|
| 96 |
+
return next(iterator)
|
| 97 |
+
except StopIteration:
|
| 98 |
+
# If the selected dataloader is exhausted, reinitialize it
|
| 99 |
+
self.iterators[name] = iter(getattr(self, name))
|
| 100 |
+
raise StopIteration
|
| 101 |
+
|
| 102 |
+
def __len__(self):
|
| 103 |
+
return self._len
|
| 104 |
+
|
| 105 |
+
def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
|
| 106 |
+
"""
|
| 107 |
+
Filter out images with none annotations or only crowd annotations
|
| 108 |
+
(i.e., images without non-crowd annotations).
|
| 109 |
+
A common training-time preprocessing on COCO dataset.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
list[dict]: the same format, but filtered.
|
| 116 |
+
"""
|
| 117 |
+
num_before = len(dataset_dicts)
|
| 118 |
+
|
| 119 |
+
def valid(anns):
|
| 120 |
+
for ann in anns:
|
| 121 |
+
if isinstance(ann, list):
|
| 122 |
+
for instance in ann:
|
| 123 |
+
if instance.get("iscrowd", 0) == 0:
|
| 124 |
+
return True
|
| 125 |
+
else:
|
| 126 |
+
if ann.get("iscrowd", 0) == 0:
|
| 127 |
+
return True
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
|
| 131 |
+
num_after = len(dataset_dicts)
|
| 132 |
+
logger = logging.getLogger(__name__)
|
| 133 |
+
logger.info(
|
| 134 |
+
"Removed {} images with no usable annotations. {} images left.".format(
|
| 135 |
+
num_before - num_after, num_after
|
| 136 |
+
)
|
| 137 |
+
)
|
| 138 |
+
return dataset_dicts
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_detection_dataset_dicts(
|
| 142 |
+
dataset_names, filter_empty=True, proposal_files=None
|
| 143 |
+
):
|
| 144 |
+
"""
|
| 145 |
+
Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
dataset_names (str or list[str]): a dataset name or a list of dataset names
|
| 149 |
+
filter_empty (bool): whether to filter out images without instance annotations
|
| 150 |
+
proposal_files (list[str]): if given, a list of object proposal files
|
| 151 |
+
that match each dataset in `dataset_names`.
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
list[dict]: a list of dicts following the standard dataset dict format.
|
| 155 |
+
"""
|
| 156 |
+
if isinstance(dataset_names, str):
|
| 157 |
+
dataset_names = [dataset_names]
|
| 158 |
+
assert len(dataset_names)
|
| 159 |
+
|
| 160 |
+
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
|
| 161 |
+
for dataset_name, dicts in zip(dataset_names, dataset_dicts):
|
| 162 |
+
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
| 163 |
+
|
| 164 |
+
if proposal_files is not None:
|
| 165 |
+
assert len(dataset_names) == len(proposal_files)
|
| 166 |
+
# load precomputed proposals from proposal files
|
| 167 |
+
dataset_dicts = [
|
| 168 |
+
load_proposals_into_dataset(dataset_i_dicts, proposal_file)
|
| 169 |
+
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
|
| 170 |
+
]
|
| 171 |
+
|
| 172 |
+
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
|
| 173 |
+
|
| 174 |
+
has_instances = "annotations" in dataset_dicts[0]
|
| 175 |
+
if filter_empty and has_instances:
|
| 176 |
+
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)
|
| 177 |
+
|
| 178 |
+
assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
|
| 179 |
+
return dataset_dicts
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _test_loader_from_config(cfg, dataset_name, mapper=None):
|
| 183 |
+
"""
|
| 184 |
+
Uses the given `dataset_name` argument (instead of the names in cfg), because the
|
| 185 |
+
standard practice is to evaluate each test set individually (not combining them).
|
| 186 |
+
"""
|
| 187 |
+
if isinstance(dataset_name, str):
|
| 188 |
+
dataset_name = [dataset_name]
|
| 189 |
+
|
| 190 |
+
dataset = get_detection_dataset_dicts(
|
| 191 |
+
dataset_name,
|
| 192 |
+
filter_empty=False,
|
| 193 |
+
proposal_files=None,
|
| 194 |
+
)
|
| 195 |
+
if mapper is None:
|
| 196 |
+
mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']})
|
| 197 |
+
mapper = DatasetMapper(mapper_cfg, False)
|
| 198 |
+
assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, "Evaluation total batchsize is not divisible by gpu number"
|
| 199 |
+
#batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size()
|
| 200 |
+
batch_size = 1
|
| 201 |
+
|
| 202 |
+
return {
|
| 203 |
+
"dataset": dataset,
|
| 204 |
+
"mapper": mapper,
|
| 205 |
+
"num_workers": cfg['DATALOADER']['NUM_WORKERS'],
|
| 206 |
+
"sampler": InferenceSampler(len(dataset)),
|
| 207 |
+
"batch_size": batch_size,
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@configurable(from_config=_test_loader_from_config)
|
| 212 |
+
def build_detection_test_loader(
|
| 213 |
+
dataset: Union[List[Any], torchdata.Dataset],
|
| 214 |
+
*,
|
| 215 |
+
mapper: Callable[[Dict[str, Any]], Any],
|
| 216 |
+
sampler: Optional[torchdata.Sampler] = None,
|
| 217 |
+
batch_size: int = 1,
|
| 218 |
+
num_workers: int = 0,
|
| 219 |
+
collate_fn: Optional[Callable[[List[Any]], Any]] = None,
|
| 220 |
+
) -> torchdata.DataLoader:
|
| 221 |
+
"""
|
| 222 |
+
Similar to `build_detection_train_loader`, with default batch size = 1,
|
| 223 |
+
and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
|
| 224 |
+
to produce the exact set of all samples.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
dataset: a list of dataset dicts,
|
| 228 |
+
or a pytorch dataset (either map-style or iterable). They can be obtained
|
| 229 |
+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
| 230 |
+
mapper: a callable which takes a sample (dict) from dataset
|
| 231 |
+
and returns the format to be consumed by the model.
|
| 232 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
|
| 233 |
+
sampler: a sampler that produces
|
| 234 |
+
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
|
| 235 |
+
which splits the dataset across all workers. Sampler must be None
|
| 236 |
+
if `dataset` is iterable.
|
| 237 |
+
batch_size: the batch size of the data loader to be created.
|
| 238 |
+
Default to 1 image per worker since this is the standard when reporting
|
| 239 |
+
inference time in papers.
|
| 240 |
+
num_workers: number of parallel data loading workers
|
| 241 |
+
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
|
| 242 |
+
Defaults to do no collation and return a list of data.
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
| 246 |
+
dataset, with test-time transformation and batching.
|
| 247 |
+
|
| 248 |
+
Examples:
|
| 249 |
+
::
|
| 250 |
+
data_loader = build_detection_test_loader(
|
| 251 |
+
DatasetRegistry.get("my_test"),
|
| 252 |
+
mapper=DatasetMapper(...))
|
| 253 |
+
|
| 254 |
+
# or, instantiate with a CfgNode:
|
| 255 |
+
data_loader = build_detection_test_loader(cfg, "my_test")
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
if isinstance(dataset, list):
|
| 259 |
+
dataset = DatasetFromList(dataset, copy=False)
|
| 260 |
+
if mapper is not None:
|
| 261 |
+
dataset = MapDataset(dataset, mapper)
|
| 262 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
| 263 |
+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
|
| 264 |
+
else:
|
| 265 |
+
if sampler is None:
|
| 266 |
+
sampler = InferenceSampler(len(dataset))
|
| 267 |
+
return torchdata.DataLoader(
|
| 268 |
+
dataset,
|
| 269 |
+
batch_size=batch_size,
|
| 270 |
+
sampler=sampler,
|
| 271 |
+
drop_last=False,
|
| 272 |
+
num_workers=num_workers,
|
| 273 |
+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
|
| 274 |
+
)
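As a usage sketch of the test loader (the in-memory dataset and identity mapper below are stand-ins, not the repo's real ones, and it assumes the @configurable wrapper forwards a direct keyword call unchanged, as in detectron2): because the default collate function is trivial_batch_collator, each batch yielded by this loader is simply a Python list of mapped samples.

# Sketch only: a stand-in dataset and identity "mapper"; real code passes DatasetMapper(cfg, is_train=False).
toy_dataset = [{"image_id": i, "file_name": f"img_{i:04d}.png"} for i in range(4)]

loader = build_detection_test_loader(
    dataset=toy_dataset,
    mapper=lambda d: d,
    batch_size=2,
    num_workers=0,
)

for batch in loader:
    # With the trivial collator, every batch is a list[dict] of length batch_size.
    assert isinstance(batch, list) and len(batch) == 2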
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None):
|
| 278 |
+
cfg_datasets = cfg['DATASETS']
|
| 279 |
+
cfg_dataloader = cfg['DATALOADER']
|
| 280 |
+
|
| 281 |
+
if dataset is None:
|
| 282 |
+
dataset = get_detection_dataset_dicts(
|
| 283 |
+
dataset_name,
|
| 284 |
+
filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'],
|
| 285 |
+
proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if mapper is None:
|
| 289 |
+
mapper = DatasetMapper(cfg, True)
|
| 290 |
+
|
| 291 |
+
if sampler is None:
|
| 292 |
+
sampler_name = cfg_dataloader['SAMPLER_TRAIN']
|
| 293 |
+
logger = logging.getLogger(__name__)
|
| 294 |
+
logger.info("Using training sampler {}".format(sampler_name))
|
| 295 |
+
sampler = TrainingSampler(len(dataset))
|
| 296 |
+
|
| 297 |
+
return {
|
| 298 |
+
"dataset": dataset,
|
| 299 |
+
"sampler": sampler,
|
| 300 |
+
"mapper": mapper,
|
| 301 |
+
"total_batch_size": cfg['TRAIN']['BATCH_SIZE_TOTAL'],
|
| 302 |
+
"aspect_ratio_grouping": cfg_dataloader['ASPECT_RATIO_GROUPING'],
|
| 303 |
+
"num_workers": cfg_dataloader['NUM_WORKERS'],
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@configurable(from_config=_train_loader_from_config)
|
| 308 |
+
def build_detection_train_loader(
|
| 309 |
+
dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
|
| 310 |
+
):
|
| 311 |
+
"""
|
| 312 |
+
Build a dataloader for object detection with some default features.
|
| 313 |
+
This interface is experimental.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
|
| 317 |
+
or a map-style pytorch dataset. They can be obtained by using
|
| 318 |
+
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
| 319 |
+
mapper (callable): a callable which takes a sample (dict) from dataset and
|
| 320 |
+
returns the format to be consumed by the model.
|
| 321 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
|
| 322 |
+
sampler (torch.utils.data.sampler.Sampler or None): a sampler that
|
| 323 |
+
produces indices to be applied on ``dataset``.
|
| 324 |
+
Default to :class:`TrainingSampler`, which coordinates a random shuffle
|
| 325 |
+
sequence across all workers.
|
| 326 |
+
total_batch_size (int): total batch size across all workers. Batching
|
| 327 |
+
simply puts data into a list.
|
| 328 |
+
aspect_ratio_grouping (bool): whether to group images with similar
|
| 329 |
+
aspect ratio for efficiency. When enabled, it requires each
|
| 330 |
+
element in dataset be a dict with keys "width" and "height".
|
| 331 |
+
num_workers (int): number of parallel data loading workers
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
torch.utils.data.DataLoader: a dataloader. Each output from it is a
|
| 335 |
+
``list[mapped_element]`` of length ``total_batch_size / num_workers``,
|
| 336 |
+
where ``mapped_element`` is produced by the ``mapper``.
|
| 337 |
+
"""
|
| 338 |
+
if isinstance(dataset, list):
|
| 339 |
+
dataset = DatasetFromList(dataset, copy=False)
|
| 340 |
+
if mapper is not None:
|
| 341 |
+
dataset = MapDataset(dataset, mapper)
|
| 342 |
+
if sampler is None:
|
| 343 |
+
sampler = TrainingSampler(len(dataset))
|
| 344 |
+
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
|
| 345 |
+
return build_batch_data_loader(
|
| 346 |
+
dataset,
|
| 347 |
+
sampler,
|
| 348 |
+
total_batch_size,
|
| 349 |
+
aspect_ratio_grouping=aspect_ratio_grouping,
|
| 350 |
+
num_workers=num_workers,
|
| 351 |
+
)
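The total_batch_size argument is a global figure; assuming detectron2's build_batch_data_loader (which this function delegates to), it is split evenly across processes, so it must be divisible by the world size. A small sketch of that arithmetic:

# Sketch of the per-process batch split performed downstream (assumption: even split across ranks).
from detectron2.utils.comm import get_world_size

def per_gpu_batch_size(total_batch_size: int) -> int:
    world_size = get_world_size()
    assert total_batch_size % world_size == 0, (
        f"total_batch_size={total_batch_size} must be divisible by world_size={world_size}"
    )
    return total_batch_size // world_size

print(per_gpu_batch_size(32))  # e.g. 32 with one process; with 4 processes each rank would get 8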
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def get_config_from_name(cfg, dataset_name):
|
| 355 |
+
# adjust config according to dataset
|
| 356 |
+
if 'refcoco' in dataset_name:
|
| 357 |
+
cfg.update(cfg['REF'])
|
| 358 |
+
return cfg
|
| 359 |
+
elif 'cocomini' in dataset_name:
|
| 360 |
+
cfg.update(cfg['DAVIS'])
|
| 361 |
+
return cfg
|
| 362 |
+
elif 'ytvos' in dataset_name:
|
| 363 |
+
cfg.update(cfg['VOS'])
|
| 364 |
+
return cfg
|
| 365 |
+
elif 'ade600' in dataset_name:
|
| 366 |
+
cfg.update(cfg['DAVIS'])
|
| 367 |
+
return cfg
|
| 368 |
+
elif 'openimage600' in dataset_name:
|
| 369 |
+
cfg.update(cfg['DAVIS'])
|
| 370 |
+
return cfg
|
| 371 |
+
elif 'ade' in dataset_name:
|
| 372 |
+
if 'ADE20K' in cfg.keys():
|
| 373 |
+
cfg.update(cfg['ADE20K'])
|
| 374 |
+
return cfg
|
| 375 |
+
elif 'imagenet' in dataset_name:
|
| 376 |
+
if 'IMAGENET' in cfg.keys():
|
| 377 |
+
cfg.update(cfg['IMAGENET'])
|
| 378 |
+
return cfg
|
| 379 |
+
elif 'vlp' in dataset_name:
|
| 380 |
+
cfg.update(cfg['VLP'])
|
| 381 |
+
return cfg
|
| 382 |
+
elif 'coco' in dataset_name:
|
| 383 |
+
if 'COCO' in cfg.keys():
|
| 384 |
+
cfg.update(cfg['COCO'])
|
| 385 |
+
return cfg
|
| 386 |
+
elif 'voc' in dataset_name:
|
| 387 |
+
cfg.update(cfg['VOC'])
|
| 388 |
+
return cfg
|
| 389 |
+
elif 'context' in dataset_name:
|
| 390 |
+
cfg.update(cfg['CONTEXT'])
|
| 391 |
+
return cfg
|
| 392 |
+
elif 'sun' in dataset_name:
|
| 393 |
+
cfg.update(cfg['SUN'])
|
| 394 |
+
return cfg
|
| 395 |
+
elif 'scan' in dataset_name:
|
| 396 |
+
cfg.update(cfg['SCAN'])
|
| 397 |
+
return cfg
|
| 398 |
+
elif 'cityscape' in dataset_name:
|
| 399 |
+
cfg.update(cfg['CITY'])
|
| 400 |
+
return cfg
|
| 401 |
+
elif 'bdd' in dataset_name:
|
| 402 |
+
cfg.update(cfg['BDD'])
|
| 403 |
+
return cfg
|
| 404 |
+
elif 'tsv' in dataset_name:
|
| 405 |
+
cfg.update(cfg['TSV'])
|
| 406 |
+
return cfg
|
| 407 |
+
elif 'phrasecut' in dataset_name:
|
| 408 |
+
cfg.update(cfg['PHRASE'])
|
| 409 |
+
return cfg
|
| 410 |
+
elif 'object365' in dataset_name:
|
| 411 |
+
cfg.update(cfg['OBJECT365'])
|
| 412 |
+
return cfg
|
| 413 |
+
elif 'openimage' in dataset_name:
|
| 414 |
+
cfg.update(cfg['OPENIMAGE'])
|
| 415 |
+
return cfg
|
| 416 |
+
elif 'lvis' in dataset_name:
|
| 417 |
+
cfg.update(cfg['LVIS'])
|
| 418 |
+
return cfg
|
| 419 |
+
elif 'seginw' in dataset_name:
|
| 420 |
+
cfg.update(cfg['SEGINW'])
|
| 421 |
+
return cfg
|
| 422 |
+
elif 'sbd' in dataset_name:
|
| 423 |
+
cfg.update(cfg['SBD'])
|
| 424 |
+
return cfg
|
| 425 |
+
elif 'davis' in dataset_name:
|
| 426 |
+
cfg.update(cfg['DAVIS'])
|
| 427 |
+
return cfg
|
| 428 |
+
elif 'med_sam' in dataset_name:
|
| 429 |
+
cfg.update(cfg['MedSAM'])
|
| 430 |
+
return cfg
|
| 431 |
+
elif 'biomed' in dataset_name:
|
| 432 |
+
cfg.update(cfg['BioMed'])
|
| 433 |
+
return cfg
|
| 434 |
+
elif 'sam' in dataset_name:
|
| 435 |
+
cfg.update(cfg['SAM'])
|
| 436 |
+
return cfg
|
| 437 |
+
else:
|
| 438 |
+
assert False, "dataset not support."
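To make the substring dispatch above concrete, a toy sketch (dataset name and config values are made up): the per-dataset block is merged into the top level with dict.update, so downstream code simply reads the overridden keys.

# Toy illustration of the name-based override used by get_config_from_name.
cfg = {
    'INPUT': {'IMAGE_SIZE': 1024},
    'BioMed': {'INPUT': {'IMAGE_SIZE': 512, 'RANDOM_ROTATE': True}},
}
dataset_name = 'biomedparse_demo_train'  # hypothetical name containing 'biomed'
if 'biomed' in dataset_name:
    cfg.update(cfg['BioMed'])
print(cfg['INPUT'])  # {'IMAGE_SIZE': 512, 'RANDOM_ROTATE': True}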
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def build_eval_dataloader(cfg, ):
|
| 442 |
+
dataloaders = []
|
| 443 |
+
for dataset_name in cfg['DATASETS']['TEST']:
|
| 444 |
+
cfg = get_config_from_name(cfg, dataset_name)
|
| 445 |
+
# adjust mapper according to dataset
|
| 446 |
+
if dataset_name == 'imagenet_val':
|
| 447 |
+
mapper = ImageNetDatasetMapper(cfg, False)
|
| 448 |
+
elif dataset_name == 'bdd10k_val_sem_seg':
|
| 449 |
+
mapper = BDDSemDatasetMapper(cfg, False)
|
| 450 |
+
elif dataset_name in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017"]:
|
| 451 |
+
mapper = VLPreDatasetMapper(cfg, False, dataset_name)
|
| 452 |
+
elif dataset_name in ["scannet_21_val_seg", "scannet_38_val_seg", "scannet_41_val_seg"]:
|
| 453 |
+
mapper = ScanNetSegDatasetMapper(cfg, False)
|
| 454 |
+
elif dataset_name in ["scannet_21_panoptic_val", 'bdd10k_40_panoptic_val']:
|
| 455 |
+
mapper = ScanNetPanoDatasetMapper(cfg, False)
|
| 456 |
+
elif "pascalvoc_val" in dataset_name:
|
| 457 |
+
mapper = PascalVOCSegDatasetMapperIX(cfg, False, dataset_name)
|
| 458 |
+
elif 'sun' in dataset_name:
|
| 459 |
+
mapper = SunRGBDSegDatasetMapper(cfg, False)
|
| 460 |
+
elif 'refcoco' in dataset_name:
|
| 461 |
+
mapper = RefCOCODatasetMapper(cfg, False)
|
| 462 |
+
elif 'med_sam' in dataset_name:
|
| 463 |
+
mapper = MedSAMDatasetMapper(cfg, False)
|
| 464 |
+
elif 'biomed' in dataset_name:
|
| 465 |
+
mapper = BioMedDatasetMapper(cfg, False)
|
| 466 |
+
else:
|
| 467 |
+
mapper = None
|
| 468 |
+
dataloaders += [build_detection_test_loader(cfg, dataset_name, mapper=mapper)]
|
| 469 |
+
return dataloaders
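The returned loaders follow the order of cfg['DATASETS']['TEST'], so evaluation code can zip them with the dataset names; a sketch of that driver loop (cfg and model are assumed to exist, and build_evaluator is defined further below in this file):

def run_evaluation(cfg, model):
    # Sketch: evaluate `model` on every test set listed in cfg['DATASETS']['TEST'].
    results = {}
    for dataset_name, loader in zip(cfg['DATASETS']['TEST'], build_eval_dataloader(cfg)):
        evaluator = build_evaluator(cfg, dataset_name)
        evaluator.reset()
        for inputs in loader:
            evaluator.process(inputs, model(inputs))
        results[dataset_name] = evaluator.evaluate()
    return results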
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def build_train_dataloader(cfg, ):
|
| 473 |
+
dataset_names = cfg['DATASETS']['TRAIN']
|
| 474 |
+
|
| 475 |
+
loaders = {}
|
| 476 |
+
for dataset_name in dataset_names:
|
| 477 |
+
cfg = get_config_from_name(cfg, dataset_name)
|
| 478 |
+
mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME']
|
| 479 |
+
# Semantic segmentation dataset mapper
|
| 480 |
+
if mapper_name == "mask_former_semantic":
|
| 481 |
+
mapper = MaskFormerSemanticDatasetMapper(cfg, True)
|
| 482 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 483 |
+
# Panoptic segmentation dataset mapper
|
| 484 |
+
elif mapper_name == "mask_former_panoptic":
|
| 485 |
+
mapper = MaskFormerPanopticDatasetMapper(cfg, True)
|
| 486 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 487 |
+
# Instance segmentation dataset mapper
|
| 488 |
+
elif mapper_name == "mask_former_instance":
|
| 489 |
+
mapper = MaskFormerInstanceDatasetMapper(cfg, True)
|
| 490 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 491 |
+
# coco instance segmentation lsj new baseline
|
| 492 |
+
elif mapper_name == "coco_instance_lsj":
|
| 493 |
+
mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
|
| 494 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 495 |
+
# coco panoptic segmentation lsj new baseline
|
| 496 |
+
elif mapper_name == "coco_panoptic_lsj":
|
| 497 |
+
mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
|
| 498 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 499 |
+
elif mapper_name == "vlpretrain":
|
| 500 |
+
mapper = VLPreDatasetMapper(cfg, True, dataset_name)
|
| 501 |
+
loaders['vlp'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 502 |
+
elif mapper_name == "refcoco":
|
| 503 |
+
mapper = RefCOCODatasetMapper(cfg, True)
|
| 504 |
+
loaders['ref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 505 |
+
elif mapper_name == "coco_interactive":
|
| 506 |
+
mapper = COCOPanopticInteractiveDatasetMapper(cfg, True)
|
| 507 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 508 |
+
elif mapper_name == "medsam_interactive":
|
| 509 |
+
mapper = MedSAMDatasetMapper(cfg, True)
|
| 510 |
+
loaders['med_sam'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 511 |
+
elif mapper_name == "biomed_interactive":
|
| 512 |
+
mapper = BioMedDatasetMapper(cfg, True)
|
| 513 |
+
name_key = dataset_name.split("_")[1]
|
| 514 |
+
loaders[name_key] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 515 |
+
else:
|
| 516 |
+
mapper = None
|
| 517 |
+
loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
| 518 |
+
|
| 519 |
+
if len(loaders) == 1 or not cfg['LOADER'].get('JOINT', False):
|
| 520 |
+
return list(loaders.values())[0]
|
| 521 |
+
else:
|
| 522 |
+
sample_prob = cfg['LOADER'].get('SAMPLE_PROB', 'prop')
|
| 523 |
+
mixing_level = cfg['LOADER'].get('MIXING_LEVEL', 1)
|
| 524 |
+
return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco'), sample_prob=sample_prob, mixing_level=mixing_level)
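JointLoader itself lives elsewhere in the repo; as one plausible reading of sample_prob='prop' (an illustration, not the repo's implementation), the loader that supplies the next batch can be drawn with probability proportional to dataset size:

# Illustration only: proportional ("prop") sampling across task loaders.
import random

def pick_loader(loader_sizes: dict) -> str:
    names = list(loader_sizes)
    weights = [loader_sizes[n] for n in names]
    return random.choices(names, weights=weights, k=1)[0]

sizes = {'biomed_ct': 8000, 'biomed_mri': 2000}  # hypothetical dataset sizes
picks = [pick_loader(sizes) for _ in range(10000)]
print(picks.count('biomed_ct') / len(picks))  # roughly 0.8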
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def build_evaluator(cfg, dataset_name, output_folder=None):
|
| 528 |
+
"""
|
| 529 |
+
Create evaluator(s) for a given dataset.
|
| 530 |
+
This uses the special metadata "evaluator_type" associated with each
|
| 531 |
+
builtin dataset. For your own dataset, you can simply create an
|
| 532 |
+
evaluator manually in your script and do not have to worry about the
|
| 533 |
+
hacky if-else logic here.
|
| 534 |
+
"""
|
| 535 |
+
if output_folder is None:
|
| 536 |
+
output_folder = os.path.join(cfg["SAVE_DIR"], "inference")
|
| 537 |
+
evaluator_list = []
|
| 538 |
+
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
|
| 539 |
+
|
| 540 |
+
# semantic segmentation
|
| 541 |
+
if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
|
| 542 |
+
evaluator_list.append(
|
| 543 |
+
SemSegEvaluator(
|
| 544 |
+
dataset_name,
|
| 545 |
+
distributed=True,
|
| 546 |
+
output_dir=output_folder,
|
| 547 |
+
)
|
| 548 |
+
)
|
| 549 |
+
# instance segmentation
|
| 550 |
+
if evaluator_type == "coco":
|
| 551 |
+
evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
|
| 552 |
+
|
| 553 |
+
cfg_model_decoder_test = cfg["MODEL"]["DECODER"]["TEST"]
|
| 554 |
+
# panoptic segmentation
|
| 555 |
+
if evaluator_type in [
|
| 556 |
+
"coco_panoptic_seg",
|
| 557 |
+
"ade20k_panoptic_seg",
|
| 558 |
+
"cityscapes_panoptic_seg",
|
| 559 |
+
"mapillary_vistas_panoptic_seg",
|
| 560 |
+
"scannet_panoptic_seg",
|
| 561 |
+
"bdd_panoptic_pano"
|
| 562 |
+
]:
|
| 563 |
+
if cfg_model_decoder_test["PANOPTIC_ON"]:
|
| 564 |
+
evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
|
| 565 |
+
# COCO
|
| 566 |
+
if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]) or evaluator_type == "object365_od":
|
| 567 |
+
evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
|
| 568 |
+
if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]) or evaluator_type == "coco_sem_seg":
|
| 569 |
+
evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
|
| 570 |
+
# Mapillary Vistas
|
| 571 |
+
if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
|
| 572 |
+
evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
|
| 573 |
+
if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]:
|
| 574 |
+
evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
|
| 575 |
+
# Cityscapes
|
| 576 |
+
if evaluator_type == "cityscapes_instance":
|
| 577 |
+
assert (
|
| 578 |
+
torch.cuda.device_count() > comm.get_rank()
|
| 579 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
| 580 |
+
return CityscapesInstanceEvaluator(dataset_name)
|
| 581 |
+
if evaluator_type == "cityscapes_sem_seg":
|
| 582 |
+
assert (
|
| 583 |
+
torch.cuda.device_count() > comm.get_rank()
|
| 584 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
| 585 |
+
return CityscapesSemSegEvaluator(dataset_name)
|
| 586 |
+
if evaluator_type == "cityscapes_panoptic_seg":
|
| 587 |
+
if cfg_model_decoder_test["SEMANTIC_ON"]:
|
| 588 |
+
assert (
|
| 589 |
+
torch.cuda.device_count() > comm.get_rank()
|
| 590 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
| 591 |
+
evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
|
| 592 |
+
if cfg_model_decoder_test["INSTANCE_ON"]:
|
| 593 |
+
assert (
|
| 594 |
+
torch.cuda.device_count() > comm.get_rank()
|
| 595 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
| 596 |
+
evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
|
| 597 |
+
# ADE20K
|
| 598 |
+
if evaluator_type == "ade20k_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
|
| 599 |
+
evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
|
| 600 |
+
# SEGINW
|
| 601 |
+
if evaluator_type == "seginw" and cfg_model_decoder_test["INSTANCE_ON"]:
|
| 602 |
+
evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
|
| 603 |
+
# LVIS
|
| 604 |
+
if evaluator_type == "lvis":
|
| 605 |
+
return LVISEvaluator(dataset_name, output_dir=output_folder)
|
| 606 |
+
# Classification
|
| 607 |
+
if evaluator_type == "classification":
|
| 608 |
+
evaluator_list.append(ClassificationEvaluator(dataset_name, output_folder))
|
| 609 |
+
# Retrieval
|
| 610 |
+
if evaluator_type in ["retrieval"]:
|
| 611 |
+
evaluator_list.append(RetrievalEvaluator(dataset_name, output_folder, cfg['MODEL']['DECODER']['RETRIEVAL']['ENSEMBLE']))
|
| 612 |
+
if evaluator_type == "captioning":
|
| 613 |
+
evaluator_list.append(CaptioningEvaluator(dataset_name, output_folder, MetadataCatalog.get(dataset_name).gt_json))
|
| 614 |
+
if evaluator_type in ["grounding_refcoco", "grounding_phrasecut", "grounding_spatial", "grounding_entity"]:
|
| 615 |
+
evaluator_list.append(GroundingEvaluator(dataset_name))
|
| 616 |
+
# Interactive
|
| 617 |
+
if evaluator_type in ["interactive", "interactive_grounding"]:
|
| 618 |
+
evaluator_list.append(InteractiveEvaluator(dataset_name, output_dir=output_folder, max_clicks=cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'], iou_iter=cfg['STROKE_SAMPLER']['EVAL']['IOU_ITER']))
|
| 619 |
+
|
| 620 |
+
if len(evaluator_list) == 0:
|
| 621 |
+
raise NotImplementedError(
|
| 622 |
+
"no Evaluator for the dataset {} with the type {}".format(
|
| 623 |
+
dataset_name, evaluator_type
|
| 624 |
+
)
|
| 625 |
+
)
|
| 626 |
+
elif len(evaluator_list) == 1:
|
| 627 |
+
return evaluator_list[0]
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
return DatasetEvaluators(evaluator_list)
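Because the dispatch is driven entirely by the "evaluator_type" metadata, a custom dataset only needs that field set; a minimal sketch (the dataset name is hypothetical and the toy config carries only the keys this function touches):

# Sketch: register a hypothetical dataset and let build_evaluator pick the evaluator.
from detectron2.data import DatasetCatalog, MetadataCatalog

DatasetCatalog.register("toy_biomed_demo_val", lambda: [])
MetadataCatalog.get("toy_biomed_demo_val").evaluator_type = "grounding_refcoco"

toy_cfg = {"SAVE_DIR": "./output", "MODEL": {"DECODER": {"TEST": {}}}}
evaluator = build_evaluator(toy_cfg, "toy_biomed_demo_val", output_folder="./inference_demo")
print(type(evaluator).__name__)  # GroundingEvaluator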
|
datasets/dataset_mappers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
from .biomed_dataset_mapper import BioMedDatasetMapper
|
datasets/dataset_mappers/biomed_dataset_mapper.py
ADDED
|
@@ -0,0 +1,378 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
|
| 3 |
+
import copy
|
| 4 |
+
import logging
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from transformers import AutoTokenizer, LlamaForCausalLM
|
| 11 |
+
|
| 12 |
+
from detectron2.data import detection_utils as utils
|
| 13 |
+
from detectron2.data import transforms as T
|
| 14 |
+
from detectron2.data.transforms import TransformGen
|
| 15 |
+
from detectron2.structures import BitMasks, Boxes, Instances, BoxMode
|
| 16 |
+
from detectron2.structures.boxes import pairwise_iou
|
| 17 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
| 18 |
+
from detectron2.data import MetadataCatalog
|
| 19 |
+
from pycocotools import mask as coco_mask
|
| 20 |
+
|
| 21 |
+
from utilities import prompt_engineering
|
| 22 |
+
from modeling.language import build_tokenizer
|
| 23 |
+
from modeling.language.misc import text_noun_with_prompt_all
|
| 24 |
+
from modeling.utils import configurable
|
| 25 |
+
|
| 26 |
+
from ..visual_sampler.sampler import build_shape_sampler
|
| 27 |
+
|
| 28 |
+
__all__ = ["BioMedDatasetMapper"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_transform_gen(cfg, is_train):
|
| 32 |
+
"""
|
| 33 |
+
Create a list of default :class:`Augmentation` from config.
|
| 34 |
+
Now it includes resizing and flipping.
|
| 35 |
+
Returns:
|
| 36 |
+
list[Augmentation]
|
| 37 |
+
"""
|
| 38 |
+
assert is_train, "Only support training augmentation"
|
| 39 |
+
cfg_input = cfg['INPUT']
|
| 40 |
+
image_size = cfg_input['IMAGE_SIZE']
|
| 41 |
+
min_scale = cfg_input['MIN_SCALE']
|
| 42 |
+
max_scale = cfg_input['MAX_SCALE']
|
| 43 |
+
|
| 44 |
+
augmentation = []
|
| 45 |
+
|
| 46 |
+
if cfg_input['RANDOM_FLIP'] != "none":
|
| 47 |
+
augmentation.append(
|
| 48 |
+
T.RandomFlip(
|
| 49 |
+
horizontal=cfg_input['RANDOM_FLIP'] == "horizontal",
|
| 50 |
+
vertical=cfg_input['RANDOM_FLIP'] == "vertical",
|
| 51 |
+
)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
augmentation.extend([
|
| 55 |
+
T.ResizeScale(
|
| 56 |
+
min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
|
| 57 |
+
),
|
| 58 |
+
T.FixedSizeCrop(crop_size=(image_size, image_size)),
|
| 59 |
+
])
|
| 60 |
+
|
| 61 |
+
return augmentation
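A small sketch of how these TransformGens are applied later in the mapper (the config values are made up): T.apply_transform_gens transforms the image and returns the composed transforms, which the mapper reuses on masks via apply_segmentation.

# Sketch with made-up INPUT values; assumes the detectron2 augmentations imported above.
import numpy as np
from detectron2.data import transforms as T

toy_cfg = {'INPUT': {'IMAGE_SIZE': 1024, 'MIN_SCALE': 0.9, 'MAX_SCALE': 1.1, 'RANDOM_FLIP': 'horizontal'}}
tfm_gens = build_transform_gen(toy_cfg, is_train=True)

image = np.zeros((768, 1024, 3), dtype=np.uint8)
image, transforms = T.apply_transform_gens(tfm_gens, image)
# transforms.apply_segmentation(mask) would replay the same geometric ops on a mask.
print(image.shape)  # (1024, 1024, 3) after ResizeScale + FixedSizeCrop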
|
| 62 |
+
|
| 63 |
+
def build_transform_gen_se(cfg, is_train):
|
| 64 |
+
# min_scale = cfg['INPUT']['MIN_SIZE_TEST']
|
| 65 |
+
# max_scale = cfg['INPUT']['MAX_SIZE_TEST']
|
| 66 |
+
|
| 67 |
+
augmentation = []
|
| 68 |
+
# augmentation.extend([
|
| 69 |
+
# T.ResizeShortestEdge(
|
| 70 |
+
# min_scale, max_size=max_scale
|
| 71 |
+
# ),
|
| 72 |
+
# ])
|
| 73 |
+
return augmentation
|
| 74 |
+
|
| 75 |
+
def convert_coco_poly_to_mask(segmentations, height, width):
|
| 76 |
+
masks = []
|
| 77 |
+
for polygons in segmentations:
|
| 78 |
+
rles = coco_mask.frPyObjects(polygons, height, width)
|
| 79 |
+
mask = coco_mask.decode(rles)
|
| 80 |
+
if len(mask.shape) < 3:
|
| 81 |
+
mask = mask[..., None]
|
| 82 |
+
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
| 83 |
+
mask = mask.any(dim=2)
|
| 84 |
+
masks.append(mask)
|
| 85 |
+
if masks:
|
| 86 |
+
masks = torch.stack(masks, dim=0)
|
| 87 |
+
else:
|
| 88 |
+
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
| 89 |
+
return masks
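A toy example of the polygon-to-mask conversion above (values chosen only for illustration): each annotation's polygon list is rasterised through pycocotools RLE and collapsed into one binary mask per annotation.

# Toy polygon: an axis-aligned square inside a 10x10 image, in COCO [x0, y0, x1, y1, ...] format.
square = [[2.0, 2.0, 6.0, 2.0, 6.0, 6.0, 2.0, 6.0]]

masks = convert_coco_poly_to_mask([square], height=10, width=10)
print(masks.shape)      # torch.Size([1, 10, 10])
print(masks[0].sum())   # number of foreground pixels in the rasterised square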
|
| 90 |
+
|
| 91 |
+
# Adapted from the COCO panoptic dataset mapper for biomedical grounding data.
|
| 92 |
+
class BioMedDatasetMapper:
|
| 93 |
+
"""
|
| 94 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
| 95 |
+
and maps it into a format used by MaskFormer.
|
| 96 |
+
|
| 97 |
+
This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
|
| 98 |
+
|
| 99 |
+
The callable currently does the following:
|
| 100 |
+
|
| 101 |
+
1. Reads the image from "file_name"
|
| 102 |
+
2. Applies geometric transforms to the image and annotation
|
| 103 |
+
3. Finds and applies suitable cropping to the image and annotation
|
| 104 |
+
4. Prepares the image and annotation as Tensors
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
@configurable
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
is_train=True,
|
| 111 |
+
*,
|
| 112 |
+
tfm_gens,
|
| 113 |
+
image_format,
|
| 114 |
+
caption_thres,
|
| 115 |
+
grounding,
|
| 116 |
+
lvis,
|
| 117 |
+
lvis_thres,
|
| 118 |
+
max_grounding_num,
|
| 119 |
+
shape_sampler,
|
| 120 |
+
retrieval,
|
| 121 |
+
max_token_num,
|
| 122 |
+
tokenizer,
|
| 123 |
+
binary_classes: bool,
|
| 124 |
+
rotate: bool,
|
| 125 |
+
):
|
| 126 |
+
"""
|
| 127 |
+
NOTE: this interface is experimental.
|
| 128 |
+
Args:
|
| 129 |
+
is_train: for training or inference
|
| 130 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
| 131 |
+
crop_gen: crop augmentation
|
| 132 |
+
tfm_gens: data augmentation
|
| 133 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
| 134 |
+
"""
|
| 135 |
+
self.tfm_gens = tfm_gens
|
| 136 |
+
logging.getLogger(__name__).info(
|
| 137 |
+
"[BioMedDatasetMapper] Full TransformGens used in training: {}".format(
|
| 138 |
+
str(self.tfm_gens)
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
self.img_format = image_format
|
| 143 |
+
self.is_train = is_train
|
| 144 |
+
self.caption_thres = caption_thres
|
| 145 |
+
self.grounding = grounding
|
| 146 |
+
self.lvis = lvis
|
| 147 |
+
self.lvis_thres = lvis_thres
|
| 148 |
+
self.max_grounding_num = max_grounding_num
|
| 149 |
+
|
| 150 |
+
self.shape_sampler = shape_sampler
|
| 151 |
+
|
| 152 |
+
self.retrieval = retrieval
|
| 153 |
+
self.tokenizer = tokenizer
|
| 154 |
+
self.max_token_num = max_token_num
|
| 155 |
+
|
| 156 |
+
self.binary_classes = binary_classes
|
| 157 |
+
self.rotate = rotate
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
def from_config(cls, cfg, is_train=True):
|
| 161 |
+
# Build augmentation
|
| 162 |
+
if is_train:
|
| 163 |
+
tfm_gens = build_transform_gen(cfg, is_train)
|
| 164 |
+
else:
|
| 165 |
+
tfm_gens = build_transform_gen_se(cfg, is_train)
|
| 166 |
+
|
| 167 |
+
shape_sampler = build_shape_sampler(cfg)
|
| 168 |
+
|
| 169 |
+
retrieval = cfg['MODEL']['DECODER']['RETRIEVAL']['ENABLED']
|
| 170 |
+
tokenizer, max_token_num = None, None
|
| 171 |
+
if retrieval:
|
| 172 |
+
lang_model = cfg['MODEL']['TEXT']['NAME']
|
| 173 |
+
max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
|
| 174 |
+
if 'llama' in lang_model:
|
| 175 |
+
tokenizer = AutoTokenizer.from_pretrained(lang_model, padding_side='right')
|
| 176 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 177 |
+
else:
|
| 178 |
+
tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
|
| 179 |
+
|
| 180 |
+
ret = {
|
| 181 |
+
"is_train": is_train,
|
| 182 |
+
"tfm_gens": tfm_gens,
|
| 183 |
+
"image_format": cfg['INPUT']['FORMAT'],
|
| 184 |
+
"caption_thres": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'],
|
| 185 |
+
"grounding": cfg['MODEL']['DECODER']['GROUNDING']['ENABLED'],
|
| 186 |
+
"lvis": cfg['MODEL']['DECODER']['LVIS']['ENABLED'],
|
| 187 |
+
"lvis_thres": cfg['MODEL']['DECODER']['LVIS']['THRES'],
|
| 188 |
+
"max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],
|
| 189 |
+
"shape_sampler": shape_sampler,
|
| 190 |
+
"retrieval": retrieval,
|
| 191 |
+
"max_token_num": max_token_num,
|
| 192 |
+
"tokenizer": tokenizer,
|
| 193 |
+
"binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES'],
|
| 194 |
+
"rotate": cfg['INPUT']['RANDOM_ROTATE'],
|
| 195 |
+
}
|
| 196 |
+
return ret
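The __call__ method below applies an optional k*90-degree rotation; a standalone sketch of the invariant it relies on (toy arrays): rotating the image and every mask by the same k keeps them spatially aligned.

# Standalone illustration of the rotation augmentation used in __call__.
import random
import numpy as np

image = np.random.rand(64, 64, 3)
mask = (np.random.rand(64, 64) > 0.5).astype(np.uint8)

rotate_time = random.randint(1, 3)
image_rot = np.rot90(image, rotate_time)
mask_rot = np.rot90(mask, rotate_time)

assert mask_rot.sum() == mask.sum()            # rotation preserves the mask area
assert image_rot.shape[:2] == mask_rot.shape   # image and mask stay the same size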
|
| 197 |
+
|
| 198 |
+
def __call__(self, dataset_dict):
|
| 199 |
+
"""
|
| 200 |
+
Args:
|
| 201 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
dict: a format that builtin models in detectron2 accept
|
| 205 |
+
"""
|
| 206 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
| 207 |
+
while True:
|
| 208 |
+
try:
|
| 209 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
| 210 |
+
break
|
| 211 |
+
except:
|
| 212 |
+
print('Image loading error:', dataset_dict["file_name"])
|
| 213 |
+
|
| 214 |
+
utils.check_image_size(dataset_dict, image)
|
| 215 |
+
|
| 216 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
| 217 |
+
image_shape = image.shape[:2] # h, w
|
| 218 |
+
|
| 219 |
+
rotate_time = 0
|
| 220 |
+
if self.is_train and self.rotate and random.random() < 0.5:
|
| 221 |
+
rotate_time = random.randint(1, 3)
|
| 222 |
+
if rotate_time > 0:
|
| 223 |
+
image = np.rot90(image, rotate_time)
|
| 224 |
+
|
| 225 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
| 226 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
| 227 |
+
# Therefore it's important to use torch.Tensor.
|
| 228 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
grounding_anno = dataset_dict['grounding_info']
|
| 232 |
+
if len(grounding_anno) == 0:
|
| 233 |
+
print(dataset_dict['file_name'])
|
| 234 |
+
assert len(grounding_anno) > 0
|
| 235 |
+
masks_grd = []
|
| 236 |
+
texts_grd = []
|
| 237 |
+
boxes_grd = []
|
| 238 |
+
hash_grd = []
|
| 239 |
+
classes = []
|
| 240 |
+
masks_orig = []
|
| 241 |
+
for ann in grounding_anno:
|
| 242 |
+
if 'segmentation' in ann:
|
| 243 |
+
if len(ann['segmentation']) == 0:
|
| 244 |
+
print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
| 245 |
+
continue
|
| 246 |
+
rle = coco_mask.frPyObjects(
|
| 247 |
+
ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
|
| 248 |
+
m = coco_mask.decode(rle)
|
| 249 |
+
masks_orig.append(m)
|
| 250 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
| 251 |
+
m = np.sum(m, axis=2)
|
| 252 |
+
else:
|
| 253 |
+
# directly read from mask file
|
| 254 |
+
while True:
|
| 255 |
+
try:
|
| 256 |
+
m = utils.read_image(ann["mask_file"], format=self.img_format)
|
| 257 |
+
break
|
| 258 |
+
except:
|
| 259 |
+
print('Image loading error:', ann["mask_file"])
|
| 260 |
+
m = np.sum(m, axis=2)
|
| 261 |
+
m = 1 * (m > 0)
|
| 262 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
| 263 |
+
m = transforms.apply_segmentation(255*m[:,:,None])[:,:,0]
|
| 264 |
+
if rotate_time > 0:
|
| 265 |
+
m = np.rot90(m, rotate_time)
|
| 266 |
+
masks_grd += [m]
|
| 267 |
+
rand_id = random.randint(0, len(ann['sentences'])-1)
|
| 268 |
+
texts_grd.append(ann['sentences'][rand_id]['raw'].lower())
|
| 269 |
+
hash_grd.append(hash(ann['sentences'][rand_id]['raw'].lower()))
|
| 270 |
+
if self.binary_classes:
|
| 271 |
+
ann["category_id"] = 1 * (ann["category_id"] > 0)
|
| 272 |
+
classes.append(ann["category_id"])
|
| 273 |
+
#masks_grd = torch.from_numpy(np.stack(masks_grd))
|
| 274 |
+
boxes_grd = torch.tensor(boxes_grd)
|
| 275 |
+
groundings = {'masks': masks_grd, 'texts': texts_grd, 'hash': hash_grd, 'mode': 'text'}
|
| 276 |
+
dataset_dict["groundings"] = groundings
|
| 277 |
+
|
| 278 |
+
masks_grd = torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks_grd])
|
| 279 |
+
|
| 280 |
+
instances = Instances(image_shape)
|
| 281 |
+
|
| 282 |
+
instances.gt_masks = BitMasks(masks_grd)
|
| 283 |
+
instances.gt_boxes = BitMasks(masks_grd).get_bounding_boxes()
|
| 284 |
+
|
| 285 |
+
classes = np.array(classes)
|
| 286 |
+
is_things = np.array([1 for _ in classes])
|
| 287 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
| 288 |
+
instances.is_things = torch.tensor(is_things, dtype=torch.int64)
|
| 289 |
+
|
| 290 |
+
dataset_dict["instances"] = instances
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
spatial_query_utils = self.shape_sampler(instances)
|
| 294 |
+
dataset_dict['spatial_query'] = spatial_query_utils
|
| 295 |
+
|
| 296 |
+
if self.retrieval:
|
| 297 |
+
captions = dataset_dict['captions']
|
| 298 |
+
tokens = self.tokenizer(
|
| 299 |
+
captions, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
|
| 300 |
+
)
|
| 301 |
+
dataset_dict['tokens'] = {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
|
| 302 |
+
|
| 303 |
+
if self.grounding:
|
| 304 |
+
grounding_anno = dataset_dict['grounding_info']
|
| 305 |
+
grounding_len = random.randint(1, self.max_grounding_num-1)
|
| 306 |
+
if len(grounding_anno) > 0:
|
| 307 |
+
masks_grd = []
|
| 308 |
+
texts_grd = []
|
| 309 |
+
mode = 'text'
|
| 310 |
+
random.shuffle(grounding_anno)
|
| 311 |
+
for ann in grounding_anno:
|
| 312 |
+
if 'segmentation' in ann:
|
| 313 |
+
if len(ann['segmentation']) == 0:
|
| 314 |
+
print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
| 315 |
+
continue
|
| 316 |
+
rle = coco_mask.frPyObjects(
|
| 317 |
+
ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
|
| 318 |
+
m = coco_mask.decode(rle)
|
| 319 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
| 320 |
+
m = np.sum(m, axis=2)
|
| 321 |
+
else:
|
| 322 |
+
# directly read from mask file
|
| 323 |
+
while True:
|
| 324 |
+
try:
|
| 325 |
+
m = utils.read_image(ann["mask_file"], format=self.img_format)
|
| 326 |
+
break
|
| 327 |
+
except:
|
| 328 |
+
print('Image loading error:', ann["mask_file"])
|
| 329 |
+
m = np.sum(m, axis=2)
|
| 330 |
+
m = 1 * (m > 0)
|
| 331 |
+
|
| 332 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
| 333 |
+
m = transforms.apply_segmentation(m[:,:,None])[:,:,0]
|
| 334 |
+
if rotate_time > 0:
|
| 335 |
+
m = np.rot90(m, rotate_time)
|
| 336 |
+
masks_grd += [m]
|
| 337 |
+
# random select a sentence of a single annotation.
|
| 338 |
+
rand_index = random.randint(0, len(ann['sentences'])-1)
|
| 339 |
+
texts_grd += [ann['sentences'][rand_index]['raw'].lower()]
|
| 340 |
+
# max_len = min(grounding_len, len(texts_grd))
|
| 341 |
+
max_len = len(masks_grd)
|
| 342 |
+
indices = np.random.permutation(max_len)
|
| 343 |
+
texts_grd = list(np.array(texts_grd)[indices])
|
| 344 |
+
masks_grd = torch.tensor(np.stack(masks_grd)[indices])
|
| 345 |
+
hash_grd = np.array([hash(txt) for txt in texts_grd])
|
| 346 |
+
else:
|
| 347 |
+
masks_grd = instances.gt_masks.tensor
|
| 348 |
+
mode = 'class'
|
| 349 |
+
if len(masks_grd) == 0:
|
| 350 |
+
masks_grd = torch.tensor([])
|
| 351 |
+
texts_grd = ['none']
|
| 352 |
+
hash_grd = np.array([hash(txt) for txt in texts_grd])
|
| 353 |
+
else:
|
| 354 |
+
biomed_classes = ['liver', 'lung', 'kidney', 'pancreas', 'heart anatomies', 'brain anatomies',
|
| 355 |
+
'eye anatomies', 'vessel', 'other organ', 'tumor', 'infection', 'other lesion',
|
| 356 |
+
'fluid disturbance', 'other abnormality', 'histology structure', 'other']
|
| 357 |
+
if self.binary_classes:
|
| 358 |
+
biomed_classes = ['target']
|
| 359 |
+
texts_grd = np.array(biomed_classes)
|
| 360 |
+
hash_grd = np.array([hash(txt) for txt in texts_grd])
|
| 361 |
+
unique_hash_grd = np.unique(hash_grd)
|
| 362 |
+
np.random.shuffle(unique_hash_grd)
|
| 363 |
+
max_len = min(grounding_len, len(unique_hash_grd))
|
| 364 |
+
indices = np.random.permutation(max_len)
|
| 365 |
+
selected_unique_hash_grd = unique_hash_grd[indices]
|
| 366 |
+
selected_mask = np.in1d(hash_grd, selected_unique_hash_grd)
|
| 367 |
+
texts_grd = texts_grd[selected_mask]
|
| 368 |
+
hash_grd = hash_grd[selected_mask]
|
| 369 |
+
masks_grd = masks_grd[selected_mask]
|
| 370 |
+
texts_grd = [prompt_engineering(text.replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
|
| 371 |
+
for text in texts_grd]
|
| 372 |
+
groundings = {'masks': masks_grd, 'texts': texts_grd, 'mode': mode, 'hash': hash_grd}
|
| 373 |
+
dataset_dict["groundings"] = groundings
|
| 374 |
+
assert len(masks_grd) == len(dataset_dict['grounding_info']), f"len(masks_grd)={len(masks_grd)}, len(dataset_dict['grounding_info'])={len(dataset_dict['grounding_info'])}, mask shape={masks_grd.shape}, max_len={max_len}, grounding_len={grounding_len}, len(texts_grd)={len(texts_grd)}, len(hash_grd)={len(hash_grd)}"
|
| 375 |
+
# gt_masks_orisize = torch.stack([torch.from_numpy(m.squeeze(-1)) for m in masks_orig])
|
| 376 |
+
# dataset_dict['gt_masks_orisize'] = gt_masks_orisize # (nm,h,w)
|
| 377 |
+
|
| 378 |
+
return dataset_dict
|
datasets/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
| 1 |
+
from .instance_evaluation import *
|
| 2 |
+
from .classification_evaluation import *
|
| 3 |
+
from .segmentation_evaluation import *
|
| 4 |
+
from .retrieval_evaluation import *
|
| 5 |
+
#from .captioning_evaluation import *
|
| 6 |
+
from .panoptic_evaluation import *
|
| 7 |
+
from .grounding_evaluation import *
|
| 8 |
+
from .interactive_evaluation import *
|
datasets/evaluation/captioning_evaluation.py
ADDED
|
@@ -0,0 +1,129 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# --------------------------------------------------------
|
| 3 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Modified by Xueyan Zou ([email protected])
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import logging
|
| 12 |
+
import itertools
|
| 13 |
+
|
| 14 |
+
import detectron2.utils.comm as comm
|
| 15 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
| 16 |
+
|
| 17 |
+
from caption_pycocotools.coco import COCO
|
| 18 |
+
from pycocoevalcap.eval import COCOEvalCap
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CaptioningEvaluator(DatasetEvaluator):
|
| 22 |
+
"""
|
| 23 |
+
Evaluate AR for object proposals, AP for instance detection/segmentation, AP
|
| 24 |
+
for keypoint detection outputs using COCO's metrics.
|
| 25 |
+
See http://cocodataset.org/#detection-eval and
|
| 26 |
+
http://cocodataset.org/#keypoints-eval to understand its metrics.
|
| 27 |
+
The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
|
| 28 |
+
the metric cannot be computed (e.g. due to no predictions made).
|
| 29 |
+
In addition to COCO, this evaluator is able to support any bounding box detection,
|
| 30 |
+
instance segmentation, or keypoint detection dataset.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
distributed=True,
|
| 36 |
+
output_dir=None,
|
| 37 |
+
gt_json=None,
|
| 38 |
+
):
|
| 39 |
+
"""
|
| 40 |
+
Args:
|
| 41 |
+
dataset_name (str): name of the dataset to be evaluated.
|
| 42 |
+
It must have either the following corresponding metadata:
|
| 43 |
+
"json_file": the path to the COCO format annotation
|
| 44 |
+
Or it must be in detectron2's standard dataset format
|
| 45 |
+
so it can be converted to COCO format automatically.
|
| 46 |
+
tasks (tuple[str]): tasks that can be evaluated under the given
|
| 47 |
+
configuration. A task is one of "bbox", "segm", "keypoints".
|
| 48 |
+
By default, will infer this automatically from predictions.
|
| 49 |
+
distributed (True): if True, will collect results from all ranks and run evaluation
|
| 50 |
+
in the main process.
|
| 51 |
+
Otherwise, will only evaluate the results in the current process.
|
| 52 |
+
output_dir (str): optional, an output directory to dump all
|
| 53 |
+
results predicted on the dataset. The dump contains two files:
|
| 54 |
+
1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
|
| 55 |
+
contains all the results in the format they are produced by the model.
|
| 56 |
+
2. "coco_instances_results.json" a json file in COCO's result format.
|
| 57 |
+
max_dets_per_image (int): limit on the maximum number of detections per image.
|
| 58 |
+
By default in COCO, this limit is to 100, but this can be customized
|
| 59 |
+
to be greater, as is needed in evaluation metrics AP fixed and AP pool
|
| 60 |
+
(see https://arxiv.org/pdf/2102.01066.pdf)
|
| 61 |
+
This doesn't affect keypoint evaluation.
|
| 62 |
+
use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
|
| 63 |
+
Although the results should be very close to the official implementation in COCO
|
| 64 |
+
API, it is still recommended to compute results with the official API for use in
|
| 65 |
+
papers. The faster implementation also uses more RAM.
|
| 66 |
+
kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
|
| 67 |
+
See http://cocodataset.org/#keypoints-eval
|
| 68 |
+
When empty, it will use the defaults in COCO.
|
| 69 |
+
Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
|
| 70 |
+
allow_cached_coco (bool): Whether to use cached coco json from previous validation
|
| 71 |
+
runs. You should set this to False if you need to use different validation data.
|
| 72 |
+
Defaults to True.
|
| 73 |
+
"""
|
| 74 |
+
self._logger = logging.getLogger(__name__)
|
| 75 |
+
self._distributed = distributed
|
| 76 |
+
self._output_dir = output_dir
|
| 77 |
+
self._gt_json = COCO(gt_json)
|
| 78 |
+
|
| 79 |
+
def reset(self):
|
| 80 |
+
self._gen_captions = []
|
| 81 |
+
self._image_ids = []
|
| 82 |
+
|
| 83 |
+
def process(self, inputs, outputs):
|
| 84 |
+
"""
|
| 85 |
+
Args:
|
| 86 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
| 87 |
+
It is a list of dict. Each dict corresponds to an image and
|
| 88 |
+
contains keys like "height", "width", "file_name", "image_id".
|
| 89 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
| 90 |
+
"instances" that contains :class:`Instances`.
|
| 91 |
+
"""
|
| 92 |
+
for output in outputs:
|
| 93 |
+
self._image_ids.append(output['image_id'])
|
| 94 |
+
self._gen_captions.append(output['captioning_text'])
|
| 95 |
+
|
| 96 |
+
def evaluate(self, img_ids=None):
|
| 97 |
+
"""
|
| 98 |
+
Args:
|
| 99 |
+
img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
if self._distributed:
|
| 103 |
+
comm.synchronize()
|
| 104 |
+
def gather(x, move=False):
|
| 105 |
+
x = comm.gather(x)
|
| 106 |
+
x = list(itertools.chain(*x))
|
| 107 |
+
if move:
|
| 108 |
+
x = [xx.to(self._gen_captions[0].device) for xx in x]
|
| 109 |
+
return x
|
| 110 |
+
gen_captions = gather(self._gen_captions)
|
| 111 |
+
image_ids = gather(self._image_ids)
|
| 112 |
+
if not comm.is_main_process():
|
| 113 |
+
return {}
|
| 114 |
+
else:
|
| 115 |
+
gen_captions = self._gen_captions
|
| 116 |
+
image_ids = self._image_ids
|
| 117 |
+
|
| 118 |
+
assert len(gen_captions) == len(image_ids)
|
| 119 |
+
pred_captions = [{"image_id": image_id, "caption": gen_caption} for image_id, gen_caption in zip(image_ids, gen_captions)]
|
| 120 |
+
pred_pth = os.path.join(self._output_dir, 'results.json')
|
| 121 |
+
json.dump(pred_captions, open(pred_pth, "w"))
|
| 122 |
+
|
| 123 |
+
gt_captions = self._gt_json
|
| 124 |
+
pred_captions = gt_captions.loadRes(pred_pth)
|
| 125 |
+
|
| 126 |
+
cocoEval = COCOEvalCap(gt_captions, pred_captions)
|
| 127 |
+
cocoEval.params['image_id'] = pred_captions.getImgIds()
|
| 128 |
+
cocoEval.evaluate()
|
| 129 |
+
return cocoEval.eval
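The gather step above follows the usual detectron2 pattern for distributed evaluation; a generic sketch of it, independent of the captioning specifics:

# Generic sketch of the rank-gather used by the evaluators in this folder.
import itertools
import detectron2.utils.comm as comm

def gather_to_main(local_items):
    comm.synchronize()
    gathered = comm.gather(local_items)   # per-rank lists on the main process, [] elsewhere
    if not comm.is_main_process():
        return None
    return list(itertools.chain(*gathered))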
|
datasets/evaluation/classification_evaluation.py
ADDED
|
@@ -0,0 +1,76 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# --------------------------------------------------------
|
| 3 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
| 4 |
+
# Copyright (c) 2022 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Modified by Xueyan Zou ([email protected])
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
| 13 |
+
|
| 14 |
+
from utilities.misc import AverageMeter
|
| 15 |
+
from utilities.distributed import get_world_size
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.no_grad()
|
| 19 |
+
def accuracy(output, target, topk=(1,)):
|
| 20 |
+
"""Computes the precision@k for the specified values of k"""
|
| 21 |
+
if isinstance(output, list):
|
| 22 |
+
output = output[-1]
|
| 23 |
+
|
| 24 |
+
n_classes = output.size()[1]
|
| 25 |
+
maxk = min(max(topk), n_classes)
|
| 26 |
+
batch_size = target.size(0)
|
| 27 |
+
_, pred = output.topk(maxk, 1, True, True)
|
| 28 |
+
pred = pred.t()
|
| 29 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
| 30 |
+
|
| 31 |
+
res = []
|
| 32 |
+
for k in topk:
|
| 33 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
| 34 |
+
res.append(correct_k.mul_(100.0 / batch_size).item())
|
| 35 |
+
return res
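A worked example of the helper above (toy logits, batch of two, six classes): top-1 counts only the arg-max prediction, while top-5 counts a hit anywhere among the five highest-scoring classes.

# Toy check of accuracy().
import torch

logits = torch.tensor([
    [0.1, 0.2, 0.9, 0.0, 0.0, 0.0],   # arg-max is class 2
    [0.5, 0.4, 0.0, 0.0, 0.3, 0.0],   # arg-max is class 0; class 4 is third best
])
targets = torch.tensor([2, 4])

top1, top5 = accuracy(logits, targets, topk=(1, 5))
print(top1, top5)  # 50.0 100.0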
|
| 36 |
+
|
| 37 |
+
class ClassificationEvaluator(DatasetEvaluator):
|
| 38 |
+
def __init__(self, *args):
|
| 39 |
+
self.top1 = AverageMeter()
|
| 40 |
+
self.top5 = AverageMeter()
|
| 41 |
+
self._logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
def reset(self):
|
| 44 |
+
self.top1.reset()
|
| 45 |
+
self.top5.reset()
|
| 46 |
+
|
| 47 |
+
def process(self, inputs, outputs):
|
| 48 |
+
logits = torch.stack([o['pred_class'] for o in outputs])
|
| 49 |
+
y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
|
| 50 |
+
prec1, prec5 = accuracy(logits, y, (1, 5))
|
| 51 |
+
self.top1.update(prec1, y.size(0))
|
| 52 |
+
self.top5.update(prec5, y.size(0))
|
| 53 |
+
|
| 54 |
+
def evaluate(self):
|
| 55 |
+
if get_world_size() > 1:
|
| 56 |
+
tmp_tensor = torch.tensor(
|
| 57 |
+
[self.top1.sum, self.top5.sum, self.top1.count],
|
| 58 |
+
device=torch.cuda.current_device()
|
| 59 |
+
)
|
| 60 |
+
torch.distributed.all_reduce(
|
| 61 |
+
tmp_tensor, torch.distributed.ReduceOp.SUM
|
| 62 |
+
)
|
| 63 |
+
top1_sum, top5_sum, count = tmp_tensor.tolist()
|
| 64 |
+
else:
|
| 65 |
+
top1_sum = self.top1.sum
|
| 66 |
+
top5_sum = self.top5.sum
|
| 67 |
+
count = self.top1.count
|
| 68 |
+
|
| 69 |
+
results = {}
|
| 70 |
+
scores = {
|
| 71 |
+
'top1': top1_sum / count,
|
| 72 |
+
"top5": top5_sum / count
|
| 73 |
+
}
|
| 74 |
+
results['class'] = scores
|
| 75 |
+
self._logger.info(results)
|
| 76 |
+
return results
|
datasets/evaluation/grounding_evaluation.py
ADDED
|
@@ -0,0 +1,173 @@
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
| 3 |
+
# Copyright (c) 2022 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Modified by Xueyan Zou ([email protected])
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import logging
|
| 8 |
+
import torch
|
| 9 |
+
from torchvision.ops import box_iou
|
| 10 |
+
|
| 11 |
+
from detectron2.structures import BoxMode
|
| 12 |
+
from detectron2.data import MetadataCatalog
|
| 13 |
+
from detectron2.utils.comm import all_gather, is_main_process, synchronize
|
| 14 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
| 15 |
+
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
import numpy as np
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
import copy
|
| 21 |
+
|
| 22 |
+
class GroundingEvaluator(DatasetEvaluator):
|
| 23 |
+
"""
|
| 24 |
+
Evaluate grounding segmentation metrics.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
dataset_name,
|
| 30 |
+
compute_box=False,
|
| 31 |
+
distributed=True,
|
| 32 |
+
):
|
| 33 |
+
self._logger = logging.getLogger(__name__)
|
| 34 |
+
self._dataset_name = dataset_name
|
| 35 |
+
self._distributed = distributed
|
| 36 |
+
self._cpu_device = torch.device("cpu")
|
| 37 |
+
self._compute_box = compute_box
|
| 38 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 39 |
+
|
| 40 |
+
def reset(self):
|
| 41 |
+
self.cum_I = 0
|
| 42 |
+
self.cum_U = 0
|
| 43 |
+
self.mIoU = 0
|
| 44 |
        self.mDice = 0
        self.cum_mean_area = 0
        self.eval_seg_iou_list = [.5, .6, .7, .8, .9]
        self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
        self.seg_total = 0
        self.instance_results = []
        if self._compute_box:
            self.mIoU_box = 0
            self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)

    @staticmethod
    def computeIoU(pred_seg, gd_seg):
        I = (pred_seg & gd_seg)
        U = (pred_seg | gd_seg)
        return I, U

    def get_metadata(self, _input):
        """
        Extracts and returns specific metadata from the input dictionary.

        Parameters:
        _input (dict): A dictionary containing keys like 'file_name', 'image_id', and 'grounding_info'.
            The 'grounding_info' is a list of dictionaries with keys like 'area', 'iscrowd', etc.

        Returns:
        dict: A dictionary containing filtered metadata.
        """

        _input = copy.deepcopy(_input)

        selected_input_keys = ['file_name', 'image_id', 'grounding_info']
        selected_grounding_info_keys = ['area', 'mask_file', 'iscrowd', 'image_id', 'category_id', 'id', 'file_name', 'split', 'ann_id', 'ref_id']

        filtered_input = {key: _input[key] for key in selected_input_keys if key in _input}

        # Check if grounding_info is present and is a list
        if 'grounding_info' in filtered_input and isinstance(filtered_input['grounding_info'], list):
            # Filter each grounding_info dictionary
            filtered_input['grounding_info'] = [
                {key: info[key] for key in selected_grounding_info_keys if key in info}
                for info in filtered_input['grounding_info']
            ]

        return filtered_input

    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            pred = output['grounding_mask'].sigmoid() > 0.5
            # # save pixel probability
            # prob = output['grounding_mask'].sigmoid().cpu().numpy()[0] * 255
            # pred_file = input['file_name'].split('.')[0].replace('test/', 'test_pred/') + '_' + input['groundings']['texts'][0].replace(' ', '+') + '.png'
            # if not os.path.exists('/'.join(pred_file.split('/')[:-1])):
            #     os.makedirs('/'.join(pred_file.split('/')[:-1]), exist_ok=True)
            # plt.imsave(pred_file,
            #            prob.astype(np.uint8), cmap='gray')

            gt = input['groundings']['masks'].bool()
            bsi = len(pred)
            I, U = self.computeIoU(pred, gt)
            self.cum_I += I.sum().cpu()
            self.cum_U += U.sum().cpu()
            IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6)
            self.mIoU += IoU.sum().cpu()
            # Add Dice score in eval
            Dice = I.reshape(bsi,-1).sum(-1)*2.0 / (gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1) + 1e-6)
            self.mDice += Dice.sum().cpu()
            self.cum_mean_area += ((gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1)) / 2.0).sum().cpu()

            if self._compute_box:
                pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
                gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu()
                IoU_box = box_iou(pred_box, gt_box).diagonal()
                self.mIoU_box += IoU_box.sum()

            for idx in range(len(self.eval_seg_iou_list)):
                eval_seg_iou = self.eval_seg_iou_list[idx]
                self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu()
                if self._compute_box:
                    self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu()
            self.seg_total += bsi

            instance_result = {
                'metadata': self.get_metadata(input),
                'IoU': IoU.cpu().numpy().tolist(),
                'Dice': Dice.cpu().numpy().tolist(),
                'I': I.sum(dim=(1, 2)).cpu().numpy().tolist(),
                'U': U.sum(dim=(1, 2)).cpu().numpy().tolist(),
                'IoU_box': IoU_box.cpu().numpy().tolist() if self._compute_box else '',
                'pred_area': pred.reshape(bsi,-1).sum(-1).cpu().numpy().tolist(),
            }

            iou_len = IoU.shape[0]
            grounding_info_len = len(self.get_metadata(input)['grounding_info'])
            assert iou_len == grounding_info_len, f'Number of IoU scores ({iou_len}) and grounding info ({grounding_info_len}) do not match.'
            self.instance_results.append(instance_result)

    def evaluate(self):
        if self._distributed:
            synchronize()
            self.cum_I = torch.stack(all_gather(self.cum_I)).sum()
            self.cum_U = torch.stack(all_gather(self.cum_U)).sum()
            self.mIoU = torch.stack(all_gather(self.mIoU)).sum()
            self.mDice = torch.stack(all_gather(self.mDice)).sum()
            self.cum_mean_area = torch.stack(all_gather(self.cum_mean_area)).sum()
            self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0)
            self.seg_total = sum(all_gather(self.seg_total))
            self.instance_results = sum(all_gather(self.instance_results), [])
            if self._compute_box:
                self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum()
                self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0)
            if not is_main_process():
                return

        results = {}
        for idx in range(len(self.eval_seg_iou_list)):
            result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx])
            results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item()
        results['cIoU'] = (self.cum_I*100./self.cum_U).item()
        results['mIoU'] = (self.mIoU*100./self.seg_total).item()
        results['cDice'] = (self.cum_I*100./self.cum_mean_area).item()
        results['mDice'] = (self.mDice*100./self.seg_total).item()

        if self._compute_box:
            for idx in range(len(self.eval_seg_iou_list)):
                result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx])
                results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item()
            results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item()

        self._logger.info(results)
        return {'grounding': {'scores': results, 'instance_results': self.instance_results}}
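Not part of the commit above, but a minimal sketch (toy tensors, hypothetical values) of how the cumulative metrics (cIoU) and per-instance means (mIoU) reported by this evaluator differ:

import torch

# Two toy flattened binary masks (batch of 2) and their ground truth.
pred = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]]).bool()
gt   = torch.tensor([[1, 0, 0, 0], [1, 1, 1, 1]]).bool()

I = (pred & gt).sum(-1).float()   # per-sample intersection
U = (pred | gt).sum(-1).float()   # per-sample union
iou = I / (U + 1e-6)

mIoU = iou.mean()                 # mean of per-sample IoUs, as in results['mIoU']
cIoU = I.sum() / U.sum()          # cumulative IoU over all pixels, as in results['cIoU']
print(mIoU.item(), cIoU.item())   # 0.375 vs ~0.333: large objects dominate cIoU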
datasets/evaluation/instance_evaluation.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import copy
import io
import itertools
import json
import logging
import numpy as np
import os
import pickle
from collections import OrderedDict
import pycocotools.mask as mask_util
import torch
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tabulate import tabulate

import detectron2.utils.comm as comm
from detectron2.config import CfgNode
from detectron2.data import MetadataCatalog
from detectron2.data.datasets.coco import convert_to_coco_json
from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
from detectron2.evaluation.fast_eval_api import COCOeval_opt
from detectron2.structures import Boxes, BoxMode, pairwise_iou
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import create_small_table


# modified from COCOEvaluator for instance segmentation
class InstanceSegEvaluator(COCOEvaluator):
    """
    Evaluate AR for object proposals, AP for instance detection/segmentation, AP
    for keypoint detection outputs using COCO's metrics.
    See http://cocodataset.org/#detection-eval and
    http://cocodataset.org/#keypoints-eval to understand its metrics.
    The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
    the metric cannot be computed (e.g. due to no predictions made).

    In addition to COCO, this evaluator is able to support any bounding box detection,
    instance segmentation, or keypoint detection dataset.
    """

    def _eval_predictions(self, predictions, img_ids=None):
        """
        Evaluate predictions. Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        tasks = self._tasks or self._tasks_from_predictions(coco_results)

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
            # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
            # num_classes = len(all_contiguous_ids)
            # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1

            reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
            for result in coco_results:
                category_id = result["category_id"]
                # assert category_id < num_classes, (
                #     f"A prediction has class={category_id}, "
                #     f"but the dataset only has {num_classes} classes and "
                #     f"predicted class id should be in [0, {num_classes - 1}]."
                # )
                assert category_id in reverse_id_mapping, (
                    f"A prediction has class={category_id}, "
                    f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
                )
                result["category_id"] = reverse_id_mapping[category_id]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info(
            "Evaluating predictions with {} COCO API...".format(
                "unofficial" if self._use_fast_impl else "official"
            )
        )
        for task in sorted(tasks):
            assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
            coco_eval = (
                _evaluate_predictions_on_coco(
                    self._coco_api,
                    coco_results,
                    task,
                    kpt_oks_sigmas=self._kpt_oks_sigmas,
                    use_fast_impl=self._use_fast_impl,
                    img_ids=img_ids,
                    max_dets_per_image=self._max_dets_per_image,
                )
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res
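Not part of the commit: the change relative to the stock COCOEvaluator is that sparse, non-contiguous dataset ids are allowed, so only membership in the reverse mapping is asserted. A small sketch with a hypothetical mapping:

# Hypothetical dataset whose category ids are a sparse subset of COCO.
dataset_id_to_contiguous_id = {7: 0, 21: 1, 85: 2}
reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}

pred_class = 2                                      # contiguous id predicted by the model
assert pred_class in reverse_id_mapping             # no 0..N-1 range assumption
coco_category_id = reverse_id_mapping[pred_class]   # -> 85, written back into the result dict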
datasets/evaluation/interactive_evaluation.py
ADDED
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os

import numpy as np
import torch
from torchvision.ops import box_iou

from detectron2.structures import BoxMode
from detectron2.data import MetadataCatalog
from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
from detectron2.evaluation.evaluator import DatasetEvaluator


class InteractiveEvaluator(DatasetEvaluator):
    """
    Evaluate point interactive IoU metrics.
    """

    def __init__(
        self,
        dataset_name,
        output_dir,
        max_clicks=20,
        iou_iter=1,
        compute_box=False,
        distributed=True,
    ):
        self._logger = logging.getLogger(__name__)
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._cpu_device = torch.device("cpu")
        self._output_dir = output_dir

        self.max_clicks = max_clicks
        self.iou_iter = iou_iter
        meta = MetadataCatalog.get(dataset_name)

    def reset(self):
        self.iou_list = []
        self.num_samples = 0
        self.all_ious = [0.5, 0.8, 0.85, 0.9]

    def process(self, inputs, outputs):
        self.iou_list += [o['mask_iou'] for o in outputs]
        self.num_samples += len(outputs)

    def compute_noc(self):
        def _get_noc(iou_arr, iou_thr):
            vals = iou_arr >= iou_thr
            return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks

        noc_list = {}
        for iou_thr in self.all_ious:
            scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list]
            noc_list[str(iou_thr)] = scores_arr

        iou_before_max_iter = torch.stack(self.iou_list)[:, self.iou_iter-1]
        noc_list_sum = {key: sum(value)*1.0 for key, value in noc_list.items()}

        if self._distributed:
            num_samples = sum(all_gather(self.num_samples))
            noc_list_sum_gather = all_gather(noc_list_sum)
            iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu())

            noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]}
            for nlg in noc_list_sum_gather:
                for key, value in nlg.items():
                    noc_list_sum[key] += value

        pred_noc = {}
        if self._distributed and (not is_main_process()):
            return pred_noc

        for key, value in noc_list_sum.items():
            pred_noc[key] = value / num_samples

        pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples
        return pred_noc

    def evaluate(self):
        pred_noc = self.compute_noc()

        if self._distributed and (not is_main_process()):
            return

        def draw_iou_curve(iou_list, save_dir):
            iou_list = torch.stack(iou_list, dim=0)
            iou_list = iou_list.mean(dim=0).cpu().numpy()
            # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib
            import matplotlib.pyplot as plt
            plt.figure()
            plt.plot(range(1, self.max_clicks+1), iou_list)
            plt.xlabel('Number of clicks')
            plt.ylabel('IoU')

            # create directory if not exist
            import os
            output_dir = os.path.join(save_dir, 'iou_by_clicks')
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            # get current time and format in 10 digits
            import time
            current_time = time.time()
            current_time = int(current_time)
            current_time = str(current_time)

            # save iou curve
            plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time)))

        draw_iou_curve(self.iou_list, self._output_dir)
        results = {}
        for idx in range(len(self.all_ious)):
            result_str = 'noc@{}'.format(self.all_ious[idx])
            results[result_str] = pred_noc[str(self.all_ious[idx])]

        results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']

        self._logger.info(results)
        return {'interactive': results}
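Not part of the commit: a minimal sketch (toy values) of the number-of-clicks (NoC) computation used in `compute_noc` above, where each sample stores the IoU reached after click 1..max_clicks:

import torch

max_clicks = 20
iou_arr = torch.tensor([0.42, 0.61, 0.78, 0.83, 0.91])  # IoU after clicks 1..5 for one sample

def get_noc(iou_arr, iou_thr):
    vals = iou_arr >= iou_thr
    # index of the first click reaching the threshold (+1 because clicks are 1-based)
    return vals.max(dim=0)[1].item() + 1 if vals.any() else max_clicks

print(get_noc(iou_arr, 0.8))    # 4 clicks to reach IoU >= 0.8
print(get_noc(iou_arr, 0.95))   # threshold never reached -> capped at max_clicks (20)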
datasets/evaluation/panoptic_evaluation.py
ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import io
import itertools
import json
import logging
import numpy as np
import os
import tempfile
from collections import OrderedDict
from typing import Optional
from PIL import Image
from tabulate import tabulate

from detectron2.data import MetadataCatalog
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager

from detectron2.evaluation.evaluator import DatasetEvaluator

logger = logging.getLogger(__name__)


class COCOPanopticEvaluator(DatasetEvaluator):
    """
    Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
    It saves panoptic segmentation prediction in `output_dir`

    It contains a synchronize call and has to be called from all workers.
    """

    def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
        """
        Args:
            dataset_name: name of the dataset
            output_dir: output directory to save results for evaluation.
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._thing_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
        }
        self._stuff_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
        }

        self._output_dir = output_dir
        if self._output_dir is not None:
            PathManager.mkdirs(self._output_dir)

    def reset(self):
        self._predictions = []

    def _convert_category_id(self, segment_info):
        isthing = segment_info.pop("isthing", None)
        if isthing is None:
            # the model produces panoptic category id directly. No more conversion needed
            return segment_info
        if isthing is True:
            segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
                segment_info["category_id"]
            ]
        else:
            segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
                segment_info["category_id"]
            ]
        return segment_info

    def process(self, inputs, outputs):
        from panopticapi.utils import id2rgb

        for input, output in zip(inputs, outputs):
            panoptic_img, segments_info = output["panoptic_seg"]
            panoptic_img = panoptic_img.cpu().numpy()
            if segments_info is None:
                # If "segments_info" is None, we assume "panoptic_img" is a
                # H*W int32 image storing the panoptic_id in the format of
                # category_id * label_divisor + instance_id. We reserve -1 for
                # VOID label, and add 1 to panoptic_img since the official
                # evaluation script uses 0 for VOID label.
                label_divisor = self._metadata.label_divisor
                segments_info = []
                for panoptic_label in np.unique(panoptic_img):
                    if panoptic_label == -1:
                        # VOID region.
                        continue
                    pred_class = panoptic_label // label_divisor
                    isthing = (
                        pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
                    )
                    segments_info.append(
                        {
                            "id": int(panoptic_label) + 1,
                            "category_id": int(pred_class),
                            "isthing": bool(isthing),
                        }
                    )
                # Official evaluation script uses 0 for VOID label.
                panoptic_img += 1

            file_name = os.path.basename(input["file_name"])
            file_name_png = os.path.splitext(file_name)[0] + ".png"
            with io.BytesIO() as out:
                Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
                segments_info = [self._convert_category_id(x) for x in segments_info]
                self._predictions.append(
                    {
                        "image_id": input["image_id"],
                        "file_name": file_name_png,
                        "png_string": out.getvalue(),
                        "segments_info": segments_info,
                    }
                )

    def evaluate(self):
        comm.synchronize()

        self._predictions = comm.gather(self._predictions)
        self._predictions = list(itertools.chain(*self._predictions))
        if not comm.is_main_process():
            return

        # PanopticApi requires local files
        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
        gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)

        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
            logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
            for p in self._predictions:
                with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
                    f.write(p.pop("png_string"))

            with open(gt_json, "r") as f:
                json_data = json.load(f)
            json_data["annotations"] = self._predictions

            output_dir = self._output_dir or pred_dir
            predictions_json = os.path.join(output_dir, "predictions.json")
            with PathManager.open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))

            from panopticapi.evaluation import pq_compute

            with contextlib.redirect_stdout(io.StringIO()):
                pq_res = pq_compute(
                    gt_json,
                    PathManager.get_local_path(predictions_json),
                    gt_folder=gt_folder,
                    pred_folder=pred_dir,
                )

        res = {}
        res["PQ"] = 100 * pq_res["All"]["pq"]
        res["SQ"] = 100 * pq_res["All"]["sq"]
        res["RQ"] = 100 * pq_res["All"]["rq"]
        res["PQ_th"] = 100 * pq_res["Things"]["pq"]
        res["SQ_th"] = 100 * pq_res["Things"]["sq"]
        res["RQ_th"] = 100 * pq_res["Things"]["rq"]
        res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
        res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
        res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]

        results = OrderedDict({"panoptic_seg": res})
        _print_panoptic_results(pq_res)

        return results


def _print_panoptic_results(pq_res):
    headers = ["", "PQ", "SQ", "RQ", "#categories"]
    data = []
    for name in ["All", "Things", "Stuff"]:
        row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
        data.append(row)
    table = tabulate(
        data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
    )
    logger.info("Panoptic Evaluation Results:\n" + table)


if __name__ == "__main__":
    from detectron2.utils.logger import setup_logger

    logger = setup_logger()
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gt-json")
    parser.add_argument("--gt-dir")
    parser.add_argument("--pred-json")
    parser.add_argument("--pred-dir")
    args = parser.parse_args()

    from panopticapi.evaluation import pq_compute

    with contextlib.redirect_stdout(io.StringIO()):
        pq_res = pq_compute(
            args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
        )
    _print_panoptic_results(pq_res)
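Not part of the commit: a tiny sketch (hypothetical numbers) of the panoptic-id encoding that the fallback branch in `process` decodes, and why the +1 shift exists:

label_divisor = 1000
panoptic_id = 17 * label_divisor + 3        # category 17, instance 3
pred_class = panoptic_id // label_divisor   # -> 17
# -1 marks VOID in the model output; the PanopticAPI ground truth uses 0 for VOID,
# hence the `panoptic_img += 1` shift before the prediction PNG is written.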
datasets/evaluation/retrieval_evaluation.py
ADDED
@@ -0,0 +1,260 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected]), Ziyi Dou ([email protected])
# --------------------------------------------------------
import copy
import itertools
import logging
from collections import OrderedDict
import torch
from pycocotools.cocoeval import COCOeval

import detectron2.utils.comm as comm
from detectron2.evaluation.evaluator import DatasetEvaluator

try:
    from detectron2.evaluation.fast_eval_api import COCOeval_opt
except ImportError:
    COCOeval_opt = COCOeval


class RetrievalEvaluator(DatasetEvaluator):
    """
    Evaluate AR for object proposals, AP for instance detection/segmentation, AP
    for keypoint detection outputs using COCO's metrics.
    See http://cocodataset.org/#detection-eval and
    http://cocodataset.org/#keypoints-eval to understand its metrics.
    The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
    the metric cannot be computed (e.g. due to no predictions made).
    In addition to COCO, this evaluator is able to support any bounding box detection,
    instance segmentation, or keypoint detection dataset.
    """

    def __init__(
        self,
        dataset_name=None,
        output_dir=None,
        ensemble=False,
        distributed=True,
    ):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
                It must have either the following corresponding metadata:
                "json_file": the path to the COCO format annotation
                Or it must be in detectron2's standard dataset format
                so it can be converted to COCO format automatically.
            tasks (tuple[str]): tasks that can be evaluated under the given
                configuration. A task is one of "bbox", "segm", "keypoints".
                By default, will infer this automatically from predictions.
            distributed (True): if True, will collect results from all ranks and run evaluation
                in the main process.
                Otherwise, will only evaluate the results in the current process.
            output_dir (str): optional, an output directory to dump all
                results predicted on the dataset. The dump contains two files:
                1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
                contains all the results in the format they are produced by the model.
                2. "coco_instances_results.json" a json file in COCO's result format.
            max_dets_per_image (int): limit on the maximum number of detections per image.
                By default in COCO, this limit is to 100, but this can be customized
                to be greater, as is needed in evaluation metrics AP fixed and AP pool
                (see https://arxiv.org/pdf/2102.01066.pdf)
                This doesn't affect keypoint evaluation.
            use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
                Although the results should be very close to the official implementation in COCO
                API, it is still recommended to compute results with the official API for use in
                papers. The faster implementation also uses more RAM.
            kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
                See http://cocodataset.org/#keypoints-eval
                When empty, it will use the defaults in COCO.
                Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
            allow_cached_coco (bool): Whether to use cached coco json from previous validation
                runs. You should set this to False if you need to use different validation data.
                Defaults to True.
        """
        self._logger = logging.getLogger(__name__)
        self._dataset_name = dataset_name
        self._output_dir = output_dir
        self._ensemble = ensemble
        self._distributed = distributed

        if 'p2i' in dataset_name:
            self.mode = 'patch2image'
        elif 'interactive2i' in dataset_name:
            self.mode = 'interactive2image'
        else:
            self.mode = 'default'

    def reset(self):
        self._text_embeds = []
        self._image_embeds = []
        self._image_embeds2 = []
        self._text_ids = []
        self._image_ids = []

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for output in outputs:
            self._text_ids.extend(output['caption']['caption_ids'])
            self._image_ids.append(output['caption']['image_ids'])
            self._text_embeds.append(output['caption']['text_embeds'])
            self._image_embeds.append(output['caption']['image_embeds'][0])
            if self._ensemble:
                self._image_embeds2.append(output['caption']['image_embeds'][1])

    def evaluate(self, img_ids=None):
        if self.mode == 'default':
            return self.evaluate_default(img_ids)
        elif self.mode in ['patch2image', 'interactive2image']:
            return self.evaluate_p2i(img_ids)
        else:
            assert False, "Unknown mode for retrieval evaluation"

    def evaluate_default(self, img_ids=None):
        """
        Args:
            img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
        """

        if self._distributed:
            comm.synchronize()
            def gather(x, move=False):
                x = comm.gather(x)
                x = list(itertools.chain(*x))
                if move:
                    x = [xx.to(self._text_embeds[0].device) for xx in x]
                return x
            text_embeds = gather(self._text_embeds, move=True)
            image_embeds = gather(self._image_embeds, move=True)
            if self._ensemble:
                image_embeds2 = gather(self._image_embeds2, move=True)
            text_ids = gather(self._text_ids)
            image_ids = gather(self._image_ids)
            if not comm.is_main_process():
                return {}
        else:
            text_embeds = self._text_embeds
            image_embeds = self._image_embeds
            if self._ensemble:
                image_embeds2 = self._image_embeds2
            text_ids = self._text_ids
            image_ids = self._image_ids
        if len(text_embeds) == 0:
            self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
            return {}
        iids = torch.tensor(image_ids).view(-1).cuda()
        tiids = torch.tensor(text_ids).view(-1).cuda()
        image_embeds = torch.cat(image_embeds)
        text_embeds = torch.cat(text_embeds)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        scores = image_embeds @ text_embeds.t()

        if self._ensemble:
            image_embeds2 = torch.cat(image_embeds2)
            image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
            scores2 = image_embeds2 @ text_embeds.t()
            scores = scores2 * 0.5 + scores * 0.5

        topk10 = scores.topk(10, dim=1)
        topk5 = scores.topk(5, dim=1)
        topk1 = scores.topk(1, dim=1)
        topk10_iids = tiids[topk10.indices]
        topk5_iids = tiids[topk5.indices]
        topk1_iids = tiids[topk1.indices]
        tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
        tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
        tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
        topk10 = scores.topk(10, dim=0)
        topk5 = scores.topk(5, dim=0)
        topk1 = scores.topk(1, dim=0)
        topk10_iids = iids[topk10.indices]
        topk5_iids = iids[topk5.indices]
        topk1_iids = iids[topk1.indices]
        ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
        ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
        ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()
        self._results = OrderedDict()
        # Copy so the caller can do whatever with results
        self._results['recall'] = {}
        self._results['recall']['irtr'] = float("{:.3f}".format((ir_r1 + tr_r1).item() * 100))
        self._results['recall']['ir1'] = float("{:.3f}".format(ir_r1.item() * 100))
        self._results['recall']['ir5'] = float("{:.3f}".format(ir_r5.item() * 100))
        self._results['recall']['ir10'] = float("{:.3f}".format(ir_r10.item() * 100))
        self._results['recall']['tr1'] = float("{:.3f}".format(tr_r1.item() * 100))
        self._results['recall']['tr5'] = float("{:.3f}".format(tr_r5.item() * 100))
        self._results['recall']['tr10'] = float("{:.3f}".format(tr_r10.item() * 100))
        self._logger.info(self._results)
        return copy.deepcopy(self._results)

    def evaluate_p2i(self, img_ids=None):
        """
        Args:
            img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
        """

        if self._distributed:
            comm.synchronize()
            def gather(x, move=False):
                x = comm.gather(x)
                x = list(itertools.chain(*x))
                if move:
                    x = [xx.to(self._text_embeds[0].device) for xx in x]
                return x
            text_embeds = gather(self._text_embeds, move=True)
            image_embeds = gather(self._image_embeds, move=True)
            image_embeds2 = gather(self._image_embeds2, move=True)
            text_ids = gather(self._text_ids)
            image_ids = gather(self._image_ids)
            if not comm.is_main_process():
                return {}
        else:
            text_embeds = self._text_embeds
            image_embeds = self._image_embeds
            image_embeds2 = self._image_embeds2
            text_ids = self._text_ids
            image_ids = self._image_ids

        if len(text_embeds) == 0:
            self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
            return {}

        iids = torch.tensor(image_ids).view(-1).cuda()
        tiids = torch.tensor(text_ids).view(-1).cuda()
        image_embeds = torch.cat(image_embeds)
        text_embeds = torch.cat(text_embeds)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        image_embeds2 = torch.cat(image_embeds2)
        image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)

        # compute image to image retrieval
        self._results = OrderedDict()
        self._results['recall'] = {}
        ii_scores = image_embeds2 @ image_embeds.t()

        topk10 = ii_scores.topk(10, dim=1)
        topk5 = ii_scores.topk(5, dim=1)
        topk1 = ii_scores.topk(1, dim=1)
        topk10_iids = iids[topk10.indices]
        topk5_iids = iids[topk5.indices]
        topk1_iids = iids[topk1.indices]
        iir_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
        iir_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
        iir_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
        # Copy so the caller can do whatever with results
        self._results['recall']['p2ir1'] = float("{:.3f}".format(iir_r1.item() * 100))
        self._results['recall']['p2ir5'] = float("{:.3f}".format(iir_r5.item() * 100))
        self._results['recall']['p2ir10'] = float("{:.3f}".format(iir_r10.item() * 100))
        self._logger.info(self._results)
        return copy.deepcopy(self._results)
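Not part of the commit: a minimal CPU-only sketch (random toy embeddings, hypothetical ids) of the recall@K logic used above for text retrieval:

import torch

# 3 images and 3 captions, matched by shared ids.
image_embeds = torch.nn.functional.normalize(torch.randn(3, 8), dim=-1)
text_embeds = torch.nn.functional.normalize(torch.randn(3, 8), dim=-1)
iids = torch.tensor([0, 1, 2])
tiids = torch.tensor([0, 1, 2])

scores = image_embeds @ text_embeds.t()               # image-to-text similarity matrix
topk1_iids = tiids[scores.topk(1, dim=1).indices]      # id of the best caption per image
tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
print(tr_r1.item())   # fraction of images whose top-1 retrieved caption id matches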
datasets/evaluation/segmentation_evaluation.py
ADDED
@@ -0,0 +1,195 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import json
import logging
import numpy as np
import os
from collections import OrderedDict
import PIL.Image as Image
import pycocotools.mask as mask_util
import torch

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.comm import all_gather, is_main_process
from detectron2.utils.file_io import PathManager
from detectron2.evaluation.evaluator import DatasetEvaluator
from utilities.distributed import synchronize

from ..semseg_loader import load_semseg


class SemSegEvaluator(DatasetEvaluator):
    """
    Evaluate semantic segmentation metrics.
    """

    def __init__(
        self,
        dataset_name,
        distributed=True,
        output_dir=None,
        *,
        num_classes=None,
        ignore_label=None,
    ):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
            distributed (bool): if True, will collect results from all ranks for evaluation.
                Otherwise, will evaluate the results in the current process.
            output_dir (str): an output directory to dump results.
            num_classes, ignore_label: deprecated argument
        """
        self._logger = logging.getLogger(__name__)
        if num_classes is not None:
            self._logger.warn(
                "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
            )
        if ignore_label is not None:
            self._logger.warn(
                "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
            )
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")

        self.input_file_to_gt_file = {
            dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
            for dataset_record in DatasetCatalog.get(dataset_name)
        }

        meta = MetadataCatalog.get(dataset_name)
        # Dict that maps contiguous training ids to COCO category ids
        try:
            c2d = meta.stuff_dataset_id_to_contiguous_id
            self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
        except AttributeError:
            self._contiguous_id_to_dataset_id = None
        self._class_names = meta.stuff_classes
        self._class_offset = meta.class_offset if hasattr(meta, 'class_offset') else 0
        self._num_classes = len(meta.stuff_classes)
        self._semseg_loader = meta.semseg_loader if hasattr(meta, 'semseg_loader') else 'PIL'

        if num_classes is not None:
            assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
        self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label

    def reset(self):
        self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a model.
                It is a list of dicts. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name".
            outputs: the outputs of a model. It is either list of semantic segmentation predictions
                (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
                segmentation prediction in the same format.
        """
        for input, output in zip(inputs, outputs):
            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
            pred = np.array(output, dtype=np.int)

            with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
                gt = load_semseg(f, self._semseg_loader) - self._class_offset

            if isinstance(self._ignore_label, int):
                ignore_label = self._ignore_label - self._class_offset
                gt[gt == self._ignore_label] = self._num_classes
            elif isinstance(self._ignore_label, list):
                for ignore_label in self._ignore_label:
                    ignore_label = ignore_label - self._class_offset
                    gt[gt == ignore_label] = self._num_classes

            self._conf_matrix += np.bincount(
                (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
                minlength=self._conf_matrix.size,
            ).reshape(self._conf_matrix.shape)

            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))

    def evaluate(self):
        """
        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):

        * Mean intersection-over-union averaged across classes (mIoU)
        * Frequency Weighted IoU (fwIoU)
        * Mean pixel accuracy averaged across classes (mACC)
        * Pixel Accuracy (pACC)
        """
        if self._distributed:
            synchronize()
            conf_matrix_list = all_gather(self._conf_matrix)
            self._predictions = all_gather(self._predictions)
            self._predictions = list(itertools.chain(*self._predictions))
            if not is_main_process():
                return
            self._conf_matrix = np.zeros_like(self._conf_matrix)
            for conf_matrix in conf_matrix_list:
                self._conf_matrix += conf_matrix

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(self._predictions))

        acc = np.full(self._num_classes, np.nan, dtype=np.float)
        iou = np.full(self._num_classes, np.nan, dtype=np.float)
        tp = self._conf_matrix.diagonal()[:-1].astype(np.float)
        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float)
        class_weights = pos_gt / np.sum(pos_gt)
        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float)
        acc_valid = pos_gt > 0
        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
        iou_valid = (pos_gt + pos_pred) > 0
        union = pos_gt + pos_pred - tp
        iou[acc_valid] = tp[acc_valid] / union[acc_valid]
        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
        miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
        fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
        pacc = np.sum(tp) / np.sum(pos_gt)

        res = {}
        res["mIoU"] = 100 * miou
        res["fwIoU"] = 100 * fiou
        for i, name in enumerate(self._class_names):
            res["IoU-{}".format(name)] = 100 * iou[i]
        res["mACC"] = 100 * macc
        res["pACC"] = 100 * pacc
        for i, name in enumerate(self._class_names):
            res["ACC-{}".format(name)] = 100 * acc[i]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(res, f)
        results = OrderedDict({"sem_seg": res})
        self._logger.info(results)
        return results

    def encode_json_sem_seg(self, sem_seg, input_file_name):
        """
        Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
        See http://cocodataset.org/#format-results
        """
        json_list = []
        for label in np.unique(sem_seg):
            if self._contiguous_id_to_dataset_id is not None:
                assert (
                    label in self._contiguous_id_to_dataset_id
                ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
                dataset_id = self._contiguous_id_to_dataset_id[label]
            else:
                dataset_id = int(label)
            mask = (sem_seg == label).astype(np.uint8)
            mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
            mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
            json_list.append(
                {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
            )
        return json_list
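Not part of the commit: a minimal sketch (toy 2-class labels) of the confusion-matrix accumulation and per-class IoU used by the evaluator above; the extra row/column plays the role of the ignore bin:

import numpy as np

num_classes = 2
pred = np.array([0, 0, 1, 1])
gt   = np.array([0, 1, 1, 2])          # "2" stands in for the ignore label

conf = np.bincount((num_classes + 1) * pred + gt,
                   minlength=(num_classes + 1) ** 2).reshape(num_classes + 1, num_classes + 1)

tp = conf.diagonal()[:-1].astype(float)
pos_gt = conf[:-1, :-1].sum(axis=0).astype(float)
pos_pred = conf[:-1, :-1].sum(axis=1).astype(float)
iou = tp / (pos_gt + pos_pred - tp)
print(iou, iou.mean())                  # per-class IoU and mIoU for the two classes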
datasets/refer.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__author__ = 'licheng'
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
This interface provides access to four datasets:
|
| 5 |
+
1) refclef
|
| 6 |
+
2) refcoco
|
| 7 |
+
3) refcoco+
|
| 8 |
+
4) refcocog
|
| 9 |
+
split by unc and google
|
| 10 |
+
|
| 11 |
+
The following API functions are defined:
|
| 12 |
+
REFER - REFER api class
|
| 13 |
+
getRefIds - get ref ids that satisfy given filter conditions.
|
| 14 |
+
getAnnIds - get ann ids that satisfy given filter conditions.
|
| 15 |
+
getImgIds - get image ids that satisfy given filter conditions.
|
| 16 |
+
getCatIds - get category ids that satisfy given filter conditions.
|
| 17 |
+
loadRefs - load refs with the specified ref ids.
|
| 18 |
+
loadAnns - load anns with the specified ann ids.
|
| 19 |
+
loadImgs - load images with the specified image ids.
|
| 20 |
+
loadCats - load category names with the specified category ids.
|
| 21 |
+
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
|
| 22 |
+
showRef - show image, segmentation or box of the referred object with the ref
|
| 23 |
+
getMask - get mask and area of the referred object given ref
|
| 24 |
+
showMask - show mask of the referred object given ref
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from doctest import REPORT_ONLY_FIRST_FAILURE
|
| 28 |
+
import sys
|
| 29 |
+
import os.path as osp
|
| 30 |
+
import json
|
| 31 |
+
import pickle
|
| 32 |
+
import time
|
| 33 |
+
import itertools
|
| 34 |
+
import skimage.io as io
|
| 35 |
+
import matplotlib.pyplot as plt
|
| 36 |
+
from matplotlib.collections import PatchCollection
|
| 37 |
+
from matplotlib.patches import Polygon, Rectangle
|
| 38 |
+
from pprint import pprint
|
| 39 |
+
import numpy as np
|
| 40 |
+
from pycocotools import mask
|
| 41 |
+
# import cv2
|
| 42 |
+
# from skimage.measure import label, regionprops
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class REFER:
|
| 46 |
+
def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
|
| 47 |
+
# provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
|
| 48 |
+
# also provide dataset name and splitBy information
|
| 49 |
+
# e.g., dataset = 'refcoco', splitBy = 'unc'
|
| 50 |
+
print('loading dataset {} into memory...'.format(dataset))
|
| 51 |
+
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
|
| 52 |
+
self.DATA_DIR = osp.join(data_root, dataset)
|
| 53 |
+
if dataset in ['refcoco', 'refcoco+', 'refcocog']:
|
| 54 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
|
| 55 |
+
elif dataset == 'refclef':
|
| 56 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
|
| 57 |
+
else:
|
| 58 |
+
print('No refer dataset is called [{}]'.format(dataset))
|
| 59 |
+
sys.exit()
|
| 60 |
+
|
| 61 |
+
# load refs from data/dataset/refs(dataset).json
|
| 62 |
+
tic = time.time()
|
| 63 |
+
ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
|
| 64 |
+
self.data = {}
|
| 65 |
+
self.data['dataset'] = dataset
|
| 66 |
+
self.data['refs'] = pickle.load(open(ref_file, 'rb'))
|
| 67 |
+
|
| 68 |
+
# load annotations from data/dataset/instances.json
|
| 69 |
+
instances_file = osp.join(self.DATA_DIR, 'instances.json')
|
| 70 |
+
instances = json.load(open(instances_file, 'r'))
|
| 71 |
+
self.data['images'] = instances['images']
|
| 72 |
+
self.data['annotations'] = instances['annotations']
|
| 73 |
+
self.data['categories'] = instances['categories']
|
| 74 |
+
|
| 75 |
+
# create index
|
| 76 |
+
self.createIndex()
|
| 77 |
+
print('DONE (t=%.2fs)'.format(time.time()-tic))
|
| 78 |
+
|
| 79 |
+
def createIndex(self):
|
| 80 |
+
# create sets of mapping
|
| 81 |
+
# 1) Refs: {ref_id: ref}
|
| 82 |
+
# 2) Anns: {ann_id: ann}
|
| 83 |
+
# 3) Imgs: {image_id: image}
|
| 84 |
+
# 4) Cats: {category_id: category_name}
|
| 85 |
+
# 5) Sents: {sent_id: sent}
|
| 86 |
+
# 6) imgToRefs: {image_id: refs}
|
| 87 |
+
# 7) imgToAnns: {image_id: anns}
|
| 88 |
+
# 8) refToAnn: {ref_id: ann}
|
| 89 |
+
# 9) annToRef: {ann_id: ref}
|
| 90 |
+
# 10) catToRefs: {category_id: refs}
|
| 91 |
+
# 11) sentToRef: {sent_id: ref}
|
| 92 |
+
# 12) sentToTokens: {sent_id: tokens}
|
| 93 |
+
print('creating index...')
|
| 94 |
+
# fetch info from instances
|
| 95 |
+
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
|
| 96 |
+
for ann in self.data['annotations']:
|
| 97 |
+
Anns[ann['id']] = ann
|
| 98 |
+
imgToAnns[ann['image_id']] = imgToAnns.get(
|
| 99 |
+
ann['image_id'], []) + [ann]
|
| 100 |
+
for img in self.data['images']:
|
| 101 |
+
Imgs[img['id']] = img
|
| 102 |
+
for cat in self.data['categories']:
|
| 103 |
+
Cats[cat['id']] = cat['name']
|
| 104 |
+
|
| 105 |
+
# fetch info from refs
|
| 106 |
+
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
|
| 107 |
+
Sents, sentToRef, sentToTokens = {}, {}, {}
|
| 108 |
+
for ref in self.data['refs']:
|
| 109 |
+
# ids
|
| 110 |
+
ref_id = ref['ref_id']
|
| 111 |
+
ann_id = ref['ann_id']
|
| 112 |
+
category_id = ref['category_id']
|
| 113 |
+
image_id = ref['image_id']
|
| 114 |
+
|
| 115 |
+
# add mapping related to ref
|
| 116 |
+
Refs[ref_id] = ref
|
| 117 |
+
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
|
| 118 |
+
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
|
| 119 |
+
refToAnn[ref_id] = Anns[ann_id]
|
| 120 |
+
annToRef[ann_id] = ref
|
| 121 |
+
|
| 122 |
+
# add mapping of sent
|
| 123 |
+
for sent in ref['sentences']:
|
| 124 |
+
Sents[sent['sent_id']] = sent
|
| 125 |
+
sentToRef[sent['sent_id']] = ref
|
| 126 |
+
sentToTokens[sent['sent_id']] = sent['tokens']
|
| 127 |
+
|
| 128 |
+
# create class members
|
| 129 |
+
self.Refs = Refs
|
| 130 |
+
self.Anns = Anns
|
| 131 |
+
self.Imgs = Imgs
|
| 132 |
+
self.Cats = Cats
|
| 133 |
+
self.Sents = Sents
|
| 134 |
+
self.imgToRefs = imgToRefs
|
| 135 |
+
self.imgToAnns = imgToAnns
|
| 136 |
+
self.refToAnn = refToAnn
|
| 137 |
+
self.annToRef = annToRef
|
| 138 |
+
self.catToRefs = catToRefs
|
| 139 |
+
self.sentToRef = sentToRef
|
| 140 |
+
self.sentToTokens = sentToTokens
|
| 141 |
+
print('index created.')
|
| 142 |
+
|
| 143 |
+
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
|
| 144 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
| 145 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
| 146 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
| 147 |
+
|
| 148 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
|
| 149 |
+
refs = self.data['refs']
|
| 150 |
+
else:
|
| 151 |
+
if not len(image_ids) == 0:
|
| 152 |
+
refs = [self.imgToRefs[image_id] for image_id in image_ids]
|
| 153 |
+
else:
|
| 154 |
+
refs = self.data['refs']
|
| 155 |
+
if not len(cat_ids) == 0:
|
| 156 |
+
refs = [ref for ref in refs if ref['category_id'] in cat_ids]
|
| 157 |
+
if not len(ref_ids) == 0:
|
| 158 |
+
refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
|
| 159 |
+
if not len(split) == 0:
|
| 160 |
+
if split in ['testA', 'testB', 'testC']:
|
| 161 |
+
# we also consider testAB, testBC, ...
|
| 162 |
+
refs = [ref for ref in refs if split[-1] in ref['split']]
|
| 163 |
+
elif split in ['testAB', 'testBC', 'testAC']:
|
| 164 |
+
# rarely used I guess...
|
| 165 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
| 166 |
+
elif split == 'test':
|
| 167 |
+
refs = [ref for ref in refs if 'test' in ref['split']]
|
| 168 |
+
elif split == 'train' or split == 'val':
|
| 169 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
| 170 |
+
else:
|
| 171 |
+
print('No such split [{}]'.format(split))
|
| 172 |
+
sys.exit()
|
| 173 |
+
ref_ids = [ref['ref_id'] for ref in refs]
|
| 174 |
+
return ref_ids
|
| 175 |
+
|
| 176 |
+
def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
|
| 177 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
| 178 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
| 179 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
| 180 |
+
|
| 181 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
|
| 182 |
+
ann_ids = [ann['id'] for ann in self.data['annotations']]
|
| 183 |
+
else:
|
| 184 |
+
if not len(image_ids) == 0:
|
| 185 |
+
lists = [self.imgToAnns[image_id]
|
| 186 |
+
for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
|
| 187 |
+
anns = list(itertools.chain.from_iterable(lists))
|
| 188 |
+
else:
|
| 189 |
+
anns = self.data['annotations']
|
| 190 |
+
if not len(cat_ids) == 0:
|
| 191 |
+
anns = [ann for ann in anns if ann['category_id'] in cat_ids]
|
| 192 |
+
ann_ids = [ann['id'] for ann in anns]
|
| 193 |
+
if not len(ref_ids) == 0:
|
| 194 |
+
ids = set(ann_ids).intersection(
|
| 195 |
+
set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
|
| 196 |
+
return ann_ids
|
| 197 |
+
|
| 198 |
+
def getImgIds(self, ref_ids=[]):
|
| 199 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
| 200 |
+
|
| 201 |
+
if not len(ref_ids) == 0:
|
| 202 |
+
image_ids = list(set([self.Refs[ref_id]['image_id']
|
| 203 |
+
for ref_id in ref_ids]))
|
| 204 |
+
else:
|
| 205 |
+
image_ids = self.Imgs.keys()
|
| 206 |
+
return image_ids
|
| 207 |
+
|
| 208 |
+
def getCatIds(self):
|
| 209 |
+
return self.Cats.keys()
|
| 210 |
+
|
| 211 |
+
def loadRefs(self, ref_ids=[]):
|
| 212 |
+
if type(ref_ids) == list:
|
| 213 |
+
return [self.Refs[ref_id] for ref_id in ref_ids]
|
| 214 |
+
elif type(ref_ids) == int:
|
| 215 |
+
return [self.Refs[ref_ids]]
|
| 216 |
+
|
| 217 |
+
def loadAnns(self, ann_ids=[]):
|
| 218 |
+
if type(ann_ids) == list:
|
| 219 |
+
return [self.Anns[ann_id] for ann_id in ann_ids]
|
| 220 |
+
elif type(ann_ids) == int or type(ann_ids) == str:
|
| 221 |
+
return [self.Anns[ann_ids]]
|
| 222 |
+
|
| 223 |
+
def loadImgs(self, image_ids=[]):
|
| 224 |
+
if type(image_ids) == list:
|
| 225 |
+
return [self.Imgs[image_id] for image_id in image_ids]
|
| 226 |
+
elif type(image_ids) == int:
|
| 227 |
+
return [self.Imgs[image_ids]]
|
| 228 |
+
|
| 229 |
+
def loadCats(self, cat_ids=[]):
|
| 230 |
+
if type(cat_ids) == list:
|
| 231 |
+
return [self.Cats[cat_id] for cat_id in cat_ids]
|
| 232 |
+
elif type(cat_ids) == int:
|
| 233 |
+
return [self.Cats[cat_ids]]
|
| 234 |
+
|
| 235 |
+
def getRefBox(self, ref_id):
|
| 236 |
+
ref = self.Refs[ref_id]
|
| 237 |
+
ann = self.refToAnn[ref_id]
|
| 238 |
+
return ann['bbox'] # [x, y, w, h]
|
| 239 |
+
|
| 240 |
+
def showRef(self, ref, seg_box='seg'):
|
| 241 |
+
ax = plt.gca()
|
| 242 |
+
# show image
|
| 243 |
+
image = self.Imgs[ref['image_id']]
|
| 244 |
+
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
| 245 |
+
ax.imshow(I)
|
| 246 |
+
# show refer expression
|
| 247 |
+
for sid, sent in enumerate(ref['sentences']):
|
| 248 |
+
print('{}. {}'.format(sid+1, sent['sent']))
|
| 249 |
+
# show segmentations
|
| 250 |
+
if seg_box == 'seg':
|
| 251 |
+
ann_id = ref['ann_id']
|
| 252 |
+
ann = self.Anns[ann_id]
|
| 253 |
+
polygons = []
|
| 254 |
+
color = []
|
| 255 |
+
c = 'none'
|
| 256 |
+
if type(ann['segmentation'][0]) == list:
|
| 257 |
+
# polygon used for refcoco*
|
| 258 |
+
for seg in ann['segmentation']:
|
| 259 |
+
poly = np.array(seg).reshape((len(seg)//2, 2))
|
| 260 |
+
polygons.append(Polygon(poly, True, alpha=0.4))
|
| 261 |
+
color.append(c)
|
| 262 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
| 263 |
+
1, 1, 0, 0), linewidths=3, alpha=1)
|
| 264 |
+
ax.add_collection(p) # thick yellow polygon
|
| 265 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
| 266 |
+
1, 0, 0, 0), linewidths=1, alpha=1)
|
| 267 |
+
ax.add_collection(p) # thin red polygon
|
| 268 |
+
else:
|
| 269 |
+
# mask used for refclef
|
| 270 |
+
rle = ann['segmentation']
|
| 271 |
+
m = mask.decode(rle)
|
| 272 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
| 273 |
+
color_mask = np.array([2.0, 166.0, 101.0])/255
|
| 274 |
+
for i in range(3):
|
| 275 |
+
img[:, :, i] = color_mask[i]
|
| 276 |
+
ax.imshow(np.dstack((img, m*0.5)))
|
| 277 |
+
# show bounding-box
|
| 278 |
+
elif seg_box == 'box':
|
| 279 |
+
ann_id = ref['ann_id']
|
| 280 |
+
ann = self.Anns[ann_id]
|
| 281 |
+
bbox = self.getRefBox(ref['ref_id'])
|
| 282 |
+
box_plot = Rectangle(
|
| 283 |
+
(bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
|
| 284 |
+
ax.add_patch(box_plot)
|
| 285 |
+
|
| 286 |
+
def getMask(self, ref):
|
| 287 |
+
# return mask, area and mask-center
|
| 288 |
+
ann = self.refToAnn[ref['ref_id']]
|
| 289 |
+
image = self.Imgs[ref['image_id']]
|
| 290 |
+
if type(ann['segmentation'][0]) == list: # polygon
|
| 291 |
+
rle = mask.frPyObjects(
|
| 292 |
+
ann['segmentation'], image['height'], image['width'])
|
| 293 |
+
else:
|
| 294 |
+
rle = ann['segmentation']
|
| 295 |
+
m = mask.decode(rle)
|
| 296 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
| 297 |
+
m = np.sum(m, axis=2)
|
| 298 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
| 299 |
+
# compute area
|
| 300 |
+
area = sum(mask.area(rle)) # should be close to ann['area']
|
| 301 |
+
return {'mask': m, 'area': area}
|
| 302 |
+
# # position
|
| 303 |
+
# position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
|
| 304 |
+
# position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
|
| 305 |
+
# # mass position (if there were multiple regions, we use the largest one.)
|
| 306 |
+
# label_m = label(m, connectivity=m.ndim)
|
| 307 |
+
# regions = regionprops(label_m)
|
| 308 |
+
# if len(regions) > 0:
|
| 309 |
+
# largest_id = np.argmax(np.array([props.filled_area for props in regions]))
|
| 310 |
+
# largest_props = regions[largest_id]
|
| 311 |
+
# mass_y, mass_x = largest_props.centroid
|
| 312 |
+
# else:
|
| 313 |
+
# mass_x, mass_y = position_x, position_y
|
| 314 |
+
# # if centroid is not in mask, we find the closest point to it from mask
|
| 315 |
+
# if m[mass_y, mass_x] != 1:
|
| 316 |
+
# print 'Finding closes mask point ...'
|
| 317 |
+
# kernel = np.ones((10, 10),np.uint8)
|
| 318 |
+
# me = cv2.erode(m, kernel, iterations = 1)
|
| 319 |
+
# points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
|
| 320 |
+
# points = np.array(points)
|
| 321 |
+
# dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
|
| 322 |
+
# id = np.argsort(dist)[0]
|
| 323 |
+
# mass_y, mass_x = points[id]
|
| 324 |
+
# # return
|
| 325 |
+
# return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
|
| 326 |
+
# # show image and mask
|
| 327 |
+
# I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
| 328 |
+
# plt.figure()
|
| 329 |
+
# plt.imshow(I)
|
| 330 |
+
# ax = plt.gca()
|
| 331 |
+
# img = np.ones( (m.shape[0], m.shape[1], 3) )
|
| 332 |
+
# color_mask = np.array([2.0,166.0,101.0])/255
|
| 333 |
+
# for i in range(3):
|
| 334 |
+
# img[:,:,i] = color_mask[i]
|
| 335 |
+
# ax.imshow(np.dstack( (img, m*0.5) ))
|
| 336 |
+
# plt.show()
|
| 337 |
+
|
| 338 |
+
def showMask(self, ref):
|
| 339 |
+
M = self.getMask(ref)
|
| 340 |
+
msk = M['mask']
|
| 341 |
+
ax = plt.gca()
|
| 342 |
+
ax.imshow(msk)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
if __name__ == '__main__':
|
| 346 |
+
refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
|
| 347 |
+
dataset='refcocog', splitBy='google')
|
| 348 |
+
ref_ids = refer.getRefIds()
|
| 349 |
+
print(len(ref_ids))
|
| 350 |
+
|
| 351 |
+
print(len(refer.Imgs))
|
| 352 |
+
print(len(refer.imgToRefs))
|
| 353 |
+
|
| 354 |
+
ref_ids = refer.getRefIds(split='train')
|
| 355 |
+
print('There are {} training referred objects.'.format(len(ref_ids)))
|
| 356 |
+
|
| 357 |
+
for ref_id in ref_ids:
|
| 358 |
+
ref = refer.loadRefs(ref_id)[0]
|
| 359 |
+
if len(ref['sentences']) < 2:
|
| 360 |
+
continue
|
| 361 |
+
|
| 362 |
+
pprint(ref)
|
| 363 |
+
print('The label is {}.'.format(refer.Cats[ref['category_id']]))
|
| 364 |
+
|
| 365 |
+
# plt.figure()
|
| 366 |
+
# refer.showRef(ref, seg_box='box')
|
| 367 |
+
# plt.show()
|
| 368 |
+
|
| 369 |
+
# plt.figure()
|
| 370 |
+
# refer.showMask(ref)
|
| 371 |
+
# plt.show()
|
datasets/registration/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
from . import (
|
| 2 |
+
register_biomed_datasets
|
| 3 |
+
)
|
datasets/registration/register_biomed_datasets.py
ADDED
|
@@ -0,0 +1,123 @@
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
| 3 |
+
# Copyright (c) 2022 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Modified by Xueyan Zou ([email protected])
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import collections
|
| 10 |
+
|
| 11 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
| 12 |
+
from detectron2.data.datasets import load_sem_seg
|
| 13 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
| 14 |
+
from detectron2.utils.file_io import PathManager
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
_PREDEFINED_SPLITS_BIOMED = {}
|
| 18 |
+
|
| 19 |
+
# example of registering a dataset
|
| 20 |
+
datasets = ['BiomedParseData-Demo', ] # provide name of the dataset under biomedparse_datasets
|
| 21 |
+
splits = ['demo'] # provide split name, e.g., train, test, val. Here there is only one 'demo' split in the example demo dataset
|
| 22 |
+
|
| 23 |
+
# Here we register all the splits of the dataset
|
| 24 |
+
for name in datasets:
|
| 25 |
+
for split in splits:
|
| 26 |
+
dataname = f'biomed_{name.replace("/", "-")}_{split}'
|
| 27 |
+
image_root = f"{name}/{split}"
|
| 28 |
+
ann_root = f"{name}/{split}.json"
|
| 29 |
+
_PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
|
| 30 |
+
# The resulting dataset name is: biomed_BiomedParseData-Demo_demo
|
| 31 |
+
|
| 32 |
+
# # Add your dataset here
|
| 33 |
+
# datasets = ['YOUR_DATASET_NAME', ] # provide name of the dataset under biomedparse_datasets
|
| 34 |
+
# splits = ['train', 'test'] # provide split name, e.g., train, test, val
|
| 35 |
+
|
| 36 |
+
# # Here we register all the splits of the dataset
|
| 37 |
+
# for name in datasets:
|
| 38 |
+
# for split in splits:
|
| 39 |
+
# dataname = f'biomed_{name.replace("/", "-")}_{split}'
|
| 40 |
+
# image_root = f"{name}/{split}"
|
| 41 |
+
# ann_root = f"{name}/{split}.json"
|
| 42 |
+
# _PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
|
| 43 |
+
# # The resulting dataset names are: biomed_YOUR_DATASET_NAME_train, biomed_YOUR_DATASET_NAME_test
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_metadata():
|
| 47 |
+
meta = {}
|
| 48 |
+
return meta
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_biomed_json(image_root, annot_json, metadata):
|
| 52 |
+
"""
|
| 53 |
+
Args:
|
| 54 |
+
image_root (str): path to the image folder of the split, e.g., "BiomedParseData-Demo/demo".
|
| 55 |
+
annot_json (str): path to the annotation json of the split, e.g., "BiomedParseData-Demo/demo.json".
|
| 56 |
+
metadata (dict): dataset metadata from get_metadata() (currently unused inside this loader).
|
| 57 |
+
Returns:
|
| 58 |
+
list[dict]: a list of dicts in Detectron2 standard format. (See
|
| 59 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ )
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
with PathManager.open(annot_json) as f:
|
| 63 |
+
json_info = json.load(f)
|
| 64 |
+
|
| 65 |
+
# build dictionary for grounding
|
| 66 |
+
grd_dict = collections.defaultdict(list)
|
| 67 |
+
for grd_ann in json_info['annotations']:
|
| 68 |
+
image_id = int(grd_ann["image_id"])
|
| 69 |
+
grd_dict[image_id].append(grd_ann)
|
| 70 |
+
|
| 71 |
+
mask_root = image_root + '_mask'
|
| 72 |
+
ret = []
|
| 73 |
+
for image in json_info["images"]:
|
| 74 |
+
image_id = int(image["id"])
|
| 75 |
+
image_file = os.path.join(image_root, image['file_name'])
|
| 76 |
+
grounding_anno = grd_dict[image_id]
|
| 77 |
+
for ann in grounding_anno:
|
| 78 |
+
if 'mask_file' not in ann:
|
| 79 |
+
ann['mask_file'] = image['file_name']
|
| 80 |
+
ann['mask_file'] = os.path.join(mask_root, ann['mask_file'])
|
| 81 |
+
ret.append(
|
| 82 |
+
{
|
| 83 |
+
"file_name": image_file,
|
| 84 |
+
"image_id": image_id,
|
| 85 |
+
"grounding_info": [ann],
|
| 86 |
+
}
|
| 87 |
+
)
|
| 88 |
+
assert len(ret), f"No images found in {image_root}!"
|
| 89 |
+
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
|
| 90 |
+
return ret
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def register_biomed(
|
| 94 |
+
name, metadata, image_root, annot_json):
|
| 95 |
+
DatasetCatalog.register(
|
| 96 |
+
name,
|
| 97 |
+
lambda: load_biomed_json(image_root, annot_json, metadata),
|
| 98 |
+
)
|
| 99 |
+
MetadataCatalog.get(name).set(
|
| 100 |
+
image_root=image_root,
|
| 101 |
+
json_file=annot_json,
|
| 102 |
+
evaluator_type="grounding_refcoco",
|
| 103 |
+
ignore_label=255,
|
| 104 |
+
label_divisor=1000,
|
| 105 |
+
**metadata,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def register_all_biomed(root):
|
| 110 |
+
for (
|
| 111 |
+
prefix,
|
| 112 |
+
(image_root, annot_root),
|
| 113 |
+
) in _PREDEFINED_SPLITS_BIOMED.items():
|
| 114 |
+
register_biomed(
|
| 115 |
+
prefix,
|
| 116 |
+
get_metadata(),
|
| 117 |
+
os.path.join(root, image_root),
|
| 118 |
+
os.path.join(root, annot_root),
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
_root = os.getenv("DATASET", "datasets")
|
| 123 |
+
register_all_biomed(_root)
|
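Note: once this registration module is imported, the split declared above can be fetched through detectron2's catalogs. A minimal sketch, assuming detectron2 is installed, the script runs from the repository root (so the local datasets package is picked up), and the demo data sits under the folder named by the DATASET environment variable:

import os
os.environ.setdefault("DATASET", "biomedparse_datasets")    # assumed data root; set before importing the module

from detectron2.data import DatasetCatalog, MetadataCatalog
from datasets.registration import register_biomed_datasets  # importing runs register_all_biomed at module load

name = "biomed_BiomedParseData-Demo_demo"                    # biomed_{dataset}_{split}, as composed above
records = DatasetCatalog.get(name)                           # calls load_biomed_json on {split}.json
meta = MetadataCatalog.get(name)
print(len(records), records[0]["file_name"], meta.evaluator_type)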
datasets/semseg_loader.py
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import scipy.io
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def load_semseg(filename, loader_type):
|
| 6 |
+
if loader_type == 'PIL':
|
| 7 |
+
semseg = np.array(Image.open(filename), dtype=int)
|
| 8 |
+
elif loader_type == 'MAT':
|
| 9 |
+
semseg = scipy.io.loadmat(filename)['LabelMap']
|
| 10 |
+
return semseg
|
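Note: load_semseg just dispatches on the loader type string; a small usage sketch (paths are placeholders, and the import path is assumed from the repository layout):

from datasets.semseg_loader import load_semseg

png_labels = load_semseg('path/to/label.png', 'PIL')   # integer label map read from a PNG
mat_labels = load_semseg('path/to/label.mat', 'MAT')   # 'LabelMap' array read from a MATLAB .mat file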
datasets/utils/refcoco2json.py
ADDED
|
@@ -0,0 +1,41 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from refer import REFER
|
| 4 |
+
|
| 5 |
+
coco_root = '/pth/to/coco'
|
| 6 |
+
ref_root = '/pth/to/refcocoseg'
|
| 7 |
+
|
| 8 |
+
coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json')))
|
| 9 |
+
coco_train_id = []
|
| 10 |
+
image_annot = {}
|
| 11 |
+
for i in range(len(coco_train_annot['images'])):
|
| 12 |
+
coco_train_id.append(coco_train_annot['images'][i]['id'])
|
| 13 |
+
image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i]
|
| 14 |
+
|
| 15 |
+
refg = REFER(data_root=ref_root,
|
| 16 |
+
dataset='refcocog', splitBy='umd')
|
| 17 |
+
refg_val_ids = refg.getRefIds(split='val')
|
| 18 |
+
|
| 19 |
+
full_anno = []
|
| 20 |
+
for ref_id in refg_val_ids:
|
| 21 |
+
ref = refg.loadRefs(ref_id)[0]
|
| 22 |
+
anno = refg.refToAnn[ref_id]
|
| 23 |
+
anno.update(ref)
|
| 24 |
+
full_anno.append(anno)
|
| 25 |
+
|
| 26 |
+
imageid_list = []
|
| 27 |
+
final_anno = {}
|
| 28 |
+
for anno in full_anno:
|
| 29 |
+
imageid_list += [anno['image_id']]
|
| 30 |
+
final_anno[anno['ann_id']] = anno
|
| 31 |
+
|
| 32 |
+
annotations = [value for key, value in final_anno.items()]
|
| 33 |
+
|
| 34 |
+
iamges = []
|
| 35 |
+
for image_id in list(set(imageid_list)):
|
| 36 |
+
iamges += [image_annot[image_id]]
|
| 37 |
+
|
| 38 |
+
outputs = {'images': iamges, 'annotations': annotations}
|
| 39 |
+
print(len(iamges))
|
| 40 |
+
print(len(annotations))
|
| 41 |
+
json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_train.json'), 'w'))
|
datasets/utils/refer.py
ADDED
|
@@ -0,0 +1,372 @@
|
| 1 |
+
# This code is modified from https://github.com/lichengunc/refer, and with minor modification of python2/3 format
|
| 2 |
+
__author__ = 'licheng'
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
This interface provides access to four datasets:
|
| 6 |
+
1) refclef
|
| 7 |
+
2) refcoco
|
| 8 |
+
3) refcoco+
|
| 9 |
+
4) refcocog
|
| 10 |
+
split by unc and google
|
| 11 |
+
|
| 12 |
+
The following API functions are defined:
|
| 13 |
+
REFER - REFER api class
|
| 14 |
+
getRefIds - get ref ids that satisfy given filter conditions.
|
| 15 |
+
getAnnIds - get ann ids that satisfy given filter conditions.
|
| 16 |
+
getImgIds - get image ids that satisfy given filter conditions.
|
| 17 |
+
getCatIds - get category ids that satisfy given filter conditions.
|
| 18 |
+
loadRefs - load refs with the specified ref ids.
|
| 19 |
+
loadAnns - load anns with the specified ann ids.
|
| 20 |
+
loadImgs - load images with the specified image ids.
|
| 21 |
+
loadCats - load category names with the specified category ids.
|
| 22 |
+
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
|
| 23 |
+
showRef - show image, segmentation or box of the referred object with the ref
|
| 24 |
+
getMask - get mask and area of the referred object given ref
|
| 25 |
+
showMask - show mask of the referred object given ref
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from doctest import REPORT_ONLY_FIRST_FAILURE
|
| 29 |
+
import sys
|
| 30 |
+
import os.path as osp
|
| 31 |
+
import json
|
| 32 |
+
import pickle
|
| 33 |
+
import time
|
| 34 |
+
import itertools
|
| 35 |
+
import skimage.io as io
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
from matplotlib.collections import PatchCollection
|
| 38 |
+
from matplotlib.patches import Polygon, Rectangle
|
| 39 |
+
from pprint import pprint
|
| 40 |
+
import numpy as np
|
| 41 |
+
from pycocotools import mask
|
| 42 |
+
# import cv2
|
| 43 |
+
# from skimage.measure import label, regionprops
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class REFER:
|
| 47 |
+
def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
|
| 48 |
+
# provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
|
| 49 |
+
# also provide dataset name and splitBy information
|
| 50 |
+
# e.g., dataset = 'refcoco', splitBy = 'unc'
|
| 51 |
+
print('loading dataset {} into memory...'.format(dataset))
|
| 52 |
+
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
|
| 53 |
+
self.DATA_DIR = osp.join(data_root, dataset)
|
| 54 |
+
if dataset in ['refcoco', 'refcoco+', 'refcocog']:
|
| 55 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
|
| 56 |
+
elif dataset == 'refclef':
|
| 57 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
|
| 58 |
+
else:
|
| 59 |
+
print('No refer dataset is called [{}]'.format(dataset))
|
| 60 |
+
sys.exit()
|
| 61 |
+
|
| 62 |
+
# load refs from data/dataset/refs(dataset).json
|
| 63 |
+
tic = time.time()
|
| 64 |
+
ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
|
| 65 |
+
self.data = {}
|
| 66 |
+
self.data['dataset'] = dataset
|
| 67 |
+
self.data['refs'] = pickle.load(open(ref_file, 'rb'))
|
| 68 |
+
|
| 69 |
+
# load annotations from data/dataset/instances.json
|
| 70 |
+
instances_file = osp.join(self.DATA_DIR, 'instances.json')
|
| 71 |
+
instances = json.load(open(instances_file, 'r'))
|
| 72 |
+
self.data['images'] = instances['images']
|
| 73 |
+
self.data['annotations'] = instances['annotations']
|
| 74 |
+
self.data['categories'] = instances['categories']
|
| 75 |
+
|
| 76 |
+
# create index
|
| 77 |
+
self.createIndex()
|
| 78 |
+
print('DONE (t={:.2f}s)'.format(time.time()-tic))
|
| 79 |
+
|
| 80 |
+
def createIndex(self):
|
| 81 |
+
# create sets of mapping
|
| 82 |
+
# 1) Refs: {ref_id: ref}
|
| 83 |
+
# 2) Anns: {ann_id: ann}
|
| 84 |
+
# 3) Imgs: {image_id: image}
|
| 85 |
+
# 4) Cats: {category_id: category_name}
|
| 86 |
+
# 5) Sents: {sent_id: sent}
|
| 87 |
+
# 6) imgToRefs: {image_id: refs}
|
| 88 |
+
# 7) imgToAnns: {image_id: anns}
|
| 89 |
+
# 8) refToAnn: {ref_id: ann}
|
| 90 |
+
# 9) annToRef: {ann_id: ref}
|
| 91 |
+
# 10) catToRefs: {category_id: refs}
|
| 92 |
+
# 11) sentToRef: {sent_id: ref}
|
| 93 |
+
# 12) sentToTokens: {sent_id: tokens}
|
| 94 |
+
print('creating index...')
|
| 95 |
+
# fetch info from instances
|
| 96 |
+
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
|
| 97 |
+
for ann in self.data['annotations']:
|
| 98 |
+
Anns[ann['id']] = ann
|
| 99 |
+
imgToAnns[ann['image_id']] = imgToAnns.get(
|
| 100 |
+
ann['image_id'], []) + [ann]
|
| 101 |
+
for img in self.data['images']:
|
| 102 |
+
Imgs[img['id']] = img
|
| 103 |
+
for cat in self.data['categories']:
|
| 104 |
+
Cats[cat['id']] = cat['name']
|
| 105 |
+
|
| 106 |
+
# fetch info from refs
|
| 107 |
+
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
|
| 108 |
+
Sents, sentToRef, sentToTokens = {}, {}, {}
|
| 109 |
+
for ref in self.data['refs']:
|
| 110 |
+
# ids
|
| 111 |
+
ref_id = ref['ref_id']
|
| 112 |
+
ann_id = ref['ann_id']
|
| 113 |
+
category_id = ref['category_id']
|
| 114 |
+
image_id = ref['image_id']
|
| 115 |
+
|
| 116 |
+
# add mapping related to ref
|
| 117 |
+
Refs[ref_id] = ref
|
| 118 |
+
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
|
| 119 |
+
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
|
| 120 |
+
refToAnn[ref_id] = Anns[ann_id]
|
| 121 |
+
annToRef[ann_id] = ref
|
| 122 |
+
|
| 123 |
+
# add mapping of sent
|
| 124 |
+
for sent in ref['sentences']:
|
| 125 |
+
Sents[sent['sent_id']] = sent
|
| 126 |
+
sentToRef[sent['sent_id']] = ref
|
| 127 |
+
sentToTokens[sent['sent_id']] = sent['tokens']
|
| 128 |
+
|
| 129 |
+
# create class members
|
| 130 |
+
self.Refs = Refs
|
| 131 |
+
self.Anns = Anns
|
| 132 |
+
self.Imgs = Imgs
|
| 133 |
+
self.Cats = Cats
|
| 134 |
+
self.Sents = Sents
|
| 135 |
+
self.imgToRefs = imgToRefs
|
| 136 |
+
self.imgToAnns = imgToAnns
|
| 137 |
+
self.refToAnn = refToAnn
|
| 138 |
+
self.annToRef = annToRef
|
| 139 |
+
self.catToRefs = catToRefs
|
| 140 |
+
self.sentToRef = sentToRef
|
| 141 |
+
self.sentToTokens = sentToTokens
|
| 142 |
+
print('index created.')
|
| 143 |
+
|
| 144 |
+
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
|
| 145 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
| 146 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
| 147 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
| 148 |
+
|
| 149 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
|
| 150 |
+
refs = self.data['refs']
|
| 151 |
+
else:
|
| 152 |
+
if not len(image_ids) == 0:
|
| 153 |
+
refs = [self.imgToRefs[image_id] for image_id in image_ids]
|
| 154 |
+
else:
|
| 155 |
+
refs = self.data['refs']
|
| 156 |
+
if not len(cat_ids) == 0:
|
| 157 |
+
refs = [ref for ref in refs if ref['category_id'] in cat_ids]
|
| 158 |
+
if not len(ref_ids) == 0:
|
| 159 |
+
refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
|
| 160 |
+
if not len(split) == 0:
|
| 161 |
+
if split in ['testA', 'testB', 'testC']:
|
| 162 |
+
# we also consider testAB, testBC, ...
|
| 163 |
+
refs = [ref for ref in refs if split[-1] in ref['split']]
|
| 164 |
+
elif split in ['testAB', 'testBC', 'testAC']:
|
| 165 |
+
# rarely used I guess...
|
| 166 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
| 167 |
+
elif split == 'test':
|
| 168 |
+
refs = [ref for ref in refs if 'test' in ref['split']]
|
| 169 |
+
elif split == 'train' or split == 'val':
|
| 170 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
| 171 |
+
else:
|
| 172 |
+
print('No such split [{}]'.format(split))
|
| 173 |
+
sys.exit()
|
| 174 |
+
ref_ids = [ref['ref_id'] for ref in refs]
|
| 175 |
+
return ref_ids
|
| 176 |
+
|
| 177 |
+
def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
|
| 178 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
| 179 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
| 180 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
| 181 |
+
|
| 182 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
|
| 183 |
+
ann_ids = [ann['id'] for ann in self.data['annotations']]
|
| 184 |
+
else:
|
| 185 |
+
if not len(image_ids) == 0:
|
| 186 |
+
lists = [self.imgToAnns[image_id]
|
| 187 |
+
for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
|
| 188 |
+
anns = list(itertools.chain.from_iterable(lists))
|
| 189 |
+
else:
|
| 190 |
+
anns = self.data['annotations']
|
| 191 |
+
if not len(cat_ids) == 0:
|
| 192 |
+
anns = [ann for ann in anns if ann['category_id'] in cat_ids]
|
| 193 |
+
ann_ids = [ann['id'] for ann in anns]
|
| 194 |
+
if not len(ref_ids) == 0:
|
| 195 |
+
ids = set(ann_ids).intersection(
|
| 196 |
+
set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
|
| 197 |
+
return ann_ids
|
| 198 |
+
|
| 199 |
+
def getImgIds(self, ref_ids=[]):
|
| 200 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
| 201 |
+
|
| 202 |
+
if not len(ref_ids) == 0:
|
| 203 |
+
image_ids = list(set([self.Refs[ref_id]['image_id']
|
| 204 |
+
for ref_id in ref_ids]))
|
| 205 |
+
else:
|
| 206 |
+
image_ids = self.Imgs.keys()
|
| 207 |
+
return image_ids
|
| 208 |
+
|
| 209 |
+
def getCatIds(self):
|
| 210 |
+
return self.Cats.keys()
|
| 211 |
+
|
| 212 |
+
def loadRefs(self, ref_ids=[]):
|
| 213 |
+
if type(ref_ids) == list:
|
| 214 |
+
return [self.Refs[ref_id] for ref_id in ref_ids]
|
| 215 |
+
elif type(ref_ids) == int:
|
| 216 |
+
return [self.Refs[ref_ids]]
|
| 217 |
+
|
| 218 |
+
def loadAnns(self, ann_ids=[]):
|
| 219 |
+
if type(ann_ids) == list:
|
| 220 |
+
return [self.Anns[ann_id] for ann_id in ann_ids]
|
| 221 |
+
elif type(ann_ids) == int or type(ann_ids) == str:
|
| 222 |
+
return [self.Anns[ann_ids]]
|
| 223 |
+
|
| 224 |
+
def loadImgs(self, image_ids=[]):
|
| 225 |
+
if type(image_ids) == list:
|
| 226 |
+
return [self.Imgs[image_id] for image_id in image_ids]
|
| 227 |
+
elif type(image_ids) == int:
|
| 228 |
+
return [self.Imgs[image_ids]]
|
| 229 |
+
|
| 230 |
+
def loadCats(self, cat_ids=[]):
|
| 231 |
+
if type(cat_ids) == list:
|
| 232 |
+
return [self.Cats[cat_id] for cat_id in cat_ids]
|
| 233 |
+
elif type(cat_ids) == int:
|
| 234 |
+
return [self.Cats[cat_ids]]
|
| 235 |
+
|
| 236 |
+
def getRefBox(self, ref_id):
|
| 237 |
+
ref = self.Refs[ref_id]
|
| 238 |
+
ann = self.refToAnn[ref_id]
|
| 239 |
+
return ann['bbox'] # [x, y, w, h]
|
| 240 |
+
|
| 241 |
+
def showRef(self, ref, seg_box='seg'):
|
| 242 |
+
ax = plt.gca()
|
| 243 |
+
# show image
|
| 244 |
+
image = self.Imgs[ref['image_id']]
|
| 245 |
+
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
| 246 |
+
ax.imshow(I)
|
| 247 |
+
# show refer expression
|
| 248 |
+
for sid, sent in enumerate(ref['sentences']):
|
| 249 |
+
print('{}. {}'.format(sid+1, sent['sent']))
|
| 250 |
+
# show segmentations
|
| 251 |
+
if seg_box == 'seg':
|
| 252 |
+
ann_id = ref['ann_id']
|
| 253 |
+
ann = self.Anns[ann_id]
|
| 254 |
+
polygons = []
|
| 255 |
+
color = []
|
| 256 |
+
c = 'none'
|
| 257 |
+
if type(ann['segmentation'][0]) == list:
|
| 258 |
+
# polygon used for refcoco*
|
| 259 |
+
for seg in ann['segmentation']:
|
| 260 |
+
poly = np.array(seg).reshape((len(seg)//2, 2))
|
| 261 |
+
polygons.append(Polygon(poly, True, alpha=0.4))
|
| 262 |
+
color.append(c)
|
| 263 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
| 264 |
+
1, 1, 0, 0), linewidths=3, alpha=1)
|
| 265 |
+
ax.add_collection(p) # thick yellow polygon
|
| 266 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
| 267 |
+
1, 0, 0, 0), linewidths=1, alpha=1)
|
| 268 |
+
ax.add_collection(p) # thin red polygon
|
| 269 |
+
else:
|
| 270 |
+
# mask used for refclef
|
| 271 |
+
rle = ann['segmentation']
|
| 272 |
+
m = mask.decode(rle)
|
| 273 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
| 274 |
+
color_mask = np.array([2.0, 166.0, 101.0])/255
|
| 275 |
+
for i in range(3):
|
| 276 |
+
img[:, :, i] = color_mask[i]
|
| 277 |
+
ax.imshow(np.dstack((img, m*0.5)))
|
| 278 |
+
# show bounding-box
|
| 279 |
+
elif seg_box == 'box':
|
| 280 |
+
ann_id = ref['ann_id']
|
| 281 |
+
ann = self.Anns[ann_id]
|
| 282 |
+
bbox = self.getRefBox(ref['ref_id'])
|
| 283 |
+
box_plot = Rectangle(
|
| 284 |
+
(bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
|
| 285 |
+
ax.add_patch(box_plot)
|
| 286 |
+
|
| 287 |
+
def getMask(self, ref):
|
| 288 |
+
# return mask, area and mask-center
|
| 289 |
+
ann = self.refToAnn[ref['ref_id']]
|
| 290 |
+
image = self.Imgs[ref['image_id']]
|
| 291 |
+
if type(ann['segmentation'][0]) == list: # polygon
|
| 292 |
+
rle = mask.frPyObjects(
|
| 293 |
+
ann['segmentation'], image['height'], image['width'])
|
| 294 |
+
else:
|
| 295 |
+
rle = ann['segmentation']
|
| 296 |
+
m = mask.decode(rle)
|
| 297 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
| 298 |
+
m = np.sum(m, axis=2)
|
| 299 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
| 300 |
+
# compute area
|
| 301 |
+
area = sum(mask.area(rle)) # should be close to ann['area']
|
| 302 |
+
return {'mask': m, 'area': area}
|
| 303 |
+
# # position
|
| 304 |
+
# position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
|
| 305 |
+
# position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
|
| 306 |
+
# # mass position (if there were multiple regions, we use the largest one.)
|
| 307 |
+
# label_m = label(m, connectivity=m.ndim)
|
| 308 |
+
# regions = regionprops(label_m)
|
| 309 |
+
# if len(regions) > 0:
|
| 310 |
+
# largest_id = np.argmax(np.array([props.filled_area for props in regions]))
|
| 311 |
+
# largest_props = regions[largest_id]
|
| 312 |
+
# mass_y, mass_x = largest_props.centroid
|
| 313 |
+
# else:
|
| 314 |
+
# mass_x, mass_y = position_x, position_y
|
| 315 |
+
# # if centroid is not in mask, we find the closest point to it from mask
|
| 316 |
+
# if m[mass_y, mass_x] != 1:
|
| 317 |
+
# print 'Finding closes mask point ...'
|
| 318 |
+
# kernel = np.ones((10, 10),np.uint8)
|
| 319 |
+
# me = cv2.erode(m, kernel, iterations = 1)
|
| 320 |
+
# points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
|
| 321 |
+
# points = np.array(points)
|
| 322 |
+
# dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
|
| 323 |
+
# id = np.argsort(dist)[0]
|
| 324 |
+
# mass_y, mass_x = points[id]
|
| 325 |
+
# # return
|
| 326 |
+
# return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
|
| 327 |
+
# # show image and mask
|
| 328 |
+
# I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
| 329 |
+
# plt.figure()
|
| 330 |
+
# plt.imshow(I)
|
| 331 |
+
# ax = plt.gca()
|
| 332 |
+
# img = np.ones( (m.shape[0], m.shape[1], 3) )
|
| 333 |
+
# color_mask = np.array([2.0,166.0,101.0])/255
|
| 334 |
+
# for i in range(3):
|
| 335 |
+
# img[:,:,i] = color_mask[i]
|
| 336 |
+
# ax.imshow(np.dstack( (img, m*0.5) ))
|
| 337 |
+
# plt.show()
|
| 338 |
+
|
| 339 |
+
def showMask(self, ref):
|
| 340 |
+
M = self.getMask(ref)
|
| 341 |
+
msk = M['mask']
|
| 342 |
+
ax = plt.gca()
|
| 343 |
+
ax.imshow(msk)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if __name__ == '__main__':
|
| 347 |
+
refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
|
| 348 |
+
dataset='refcocog', splitBy='google')
|
| 349 |
+
ref_ids = refer.getRefIds()
|
| 350 |
+
print(len(ref_ids))
|
| 351 |
+
|
| 352 |
+
print(len(refer.Imgs))
|
| 353 |
+
print(len(refer.imgToRefs))
|
| 354 |
+
|
| 355 |
+
ref_ids = refer.getRefIds(split='train')
|
| 356 |
+
print('There are {} training referred objects.'.format(len(ref_ids)))
|
| 357 |
+
|
| 358 |
+
for ref_id in ref_ids:
|
| 359 |
+
ref = refer.loadRefs(ref_id)[0]
|
| 360 |
+
if len(ref['sentences']) < 2:
|
| 361 |
+
continue
|
| 362 |
+
|
| 363 |
+
pprint(ref)
|
| 364 |
+
print('The label is {}.'.format(refer.Cats[ref['category_id']]))
|
| 365 |
+
|
| 366 |
+
# plt.figure()
|
| 367 |
+
# refer.showRef(ref, seg_box='box')
|
| 368 |
+
# plt.show()
|
| 369 |
+
|
| 370 |
+
# plt.figure()
|
| 371 |
+
# refer.showMask(ref)
|
| 372 |
+
# plt.show()
|
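Note: as a quick reference for the API listed in the docstring above, a typical REFER workflow could look like the sketch below (the data_root path is a placeholder and must contain refs(umd).p and instances.json for refcocog):

from datasets.utils.refer import REFER   # import path assumed from the repository layout

refer = REFER(data_root='/path/to/refcocoseg', dataset='refcocog', splitBy='umd')
ref_ids = refer.getRefIds(split='val')        # filter refs by split
ref = refer.loadRefs(ref_ids[0])[0]           # load a single ref dict
print(ref['sentences'][0]['sent'])            # its first referring expression
box = refer.getRefBox(ref['ref_id'])          # [x, y, w, h]
m = refer.getMask(ref)                        # {'mask': HxW uint8 array, 'area': ...}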
datasets/visual_sampler/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
| 1 |
+
from .sampler import ShapeSampler
|
| 2 |
+
from .simpleclick_sampler import SimpleClickSampler
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def build_shape_sampler(cfg, **kwargs):
|
| 6 |
+
sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE']
|
| 7 |
+
if sampler_name == 'random':
|
| 8 |
+
return ShapeSampler(cfg, **kwargs)
|
| 9 |
+
elif sampler_name in ['best', 'best_random']:
|
| 10 |
+
return SimpleClickSampler(cfg, **kwargs)
|
| 11 |
+
else:
|
| 12 |
+
assert False, "not implemented"
|
datasets/visual_sampler/circle.py
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from .mask_generators import get_mask_by_input_strokes
|
| 5 |
+
|
| 6 |
+
class Circle:
|
| 7 |
+
def __init__(self, cfg, is_train=True):
|
| 8 |
+
self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES']
|
| 9 |
+
self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET']
|
| 10 |
+
self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB']
|
| 11 |
+
self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
| 12 |
+
self.is_train = is_train
|
| 13 |
+
|
| 14 |
+
@staticmethod
|
| 15 |
+
def get_stroke_preset(stroke_preset):
|
| 16 |
+
if stroke_preset == 'object_like':
|
| 17 |
+
return {
|
| 18 |
+
"nVertexBound": [5, 30],
|
| 19 |
+
"maxHeadSpeed": 15,
|
| 20 |
+
"maxHeadAcceleration": (10, 1.5),
|
| 21 |
+
"brushWidthBound": (20, 50),
|
| 22 |
+
"nMovePointRatio": 0.5,
|
| 23 |
+
"maxPiontMove": 10,
|
| 24 |
+
"maxLineAcceleration": (5, 0.5),
|
| 25 |
+
"boarderGap": None,
|
| 26 |
+
"maxInitSpeed": 10,
|
| 27 |
+
}
|
| 28 |
+
elif stroke_preset == 'object_like_middle':
|
| 29 |
+
return {
|
| 30 |
+
"nVertexBound": [5, 15],
|
| 31 |
+
"maxHeadSpeed": 8,
|
| 32 |
+
"maxHeadAcceleration": (4, 1.5),
|
| 33 |
+
"brushWidthBound": (20, 50),
|
| 34 |
+
"nMovePointRatio": 0.5,
|
| 35 |
+
"maxPiontMove": 5,
|
| 36 |
+
"maxLineAcceleration": (5, 0.5),
|
| 37 |
+
"boarderGap": None,
|
| 38 |
+
"maxInitSpeed": 10,
|
| 39 |
+
}
|
| 40 |
+
elif stroke_preset == 'object_like_small':
|
| 41 |
+
return {
|
| 42 |
+
"nVertexBound": [5, 20],
|
| 43 |
+
"maxHeadSpeed": 7,
|
| 44 |
+
"maxHeadAcceleration": (3.5, 1.5),
|
| 45 |
+
"brushWidthBound": (10, 30),
|
| 46 |
+
"nMovePointRatio": 0.5,
|
| 47 |
+
"maxPiontMove": 5,
|
| 48 |
+
"maxLineAcceleration": (3, 0.5),
|
| 49 |
+
"boarderGap": None,
|
| 50 |
+
"maxInitSpeed": 4,
|
| 51 |
+
}
|
| 52 |
+
else:
|
| 53 |
+
raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
|
| 54 |
+
|
| 55 |
+
def get_random_points_from_mask(self, mask, n=5):
|
| 56 |
+
h,w = mask.shape
|
| 57 |
+
view_mask = mask.reshape(h*w)
|
| 58 |
+
non_zero_idx = view_mask.nonzero()[:,0]
|
| 59 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:n]
|
| 60 |
+
non_zero_idx = non_zero_idx[selected_idx]
|
| 61 |
+
y = (non_zero_idx // w)*1.0
|
| 62 |
+
x = (non_zero_idx % w)*1.0
|
| 63 |
+
return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
|
| 64 |
+
|
| 65 |
+
def draw(self, mask=None, box=None):
|
| 66 |
+
if mask.sum() < 10: # if mask is nearly empty
|
| 67 |
+
return torch.zeros(mask.shape).bool()
|
| 68 |
+
if not self.is_train:
|
| 69 |
+
return self.draw_eval(mask=mask, box=box)
|
| 70 |
+
stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
|
| 71 |
+
preset = Circle.get_stroke_preset(stroke_preset_name)
|
| 72 |
+
nStroke = min(random.randint(1, self.num_stroke), mask.sum().item())
|
| 73 |
+
h,w = mask.shape
|
| 74 |
+
points = self.get_random_points_from_mask(mask, n=nStroke)
|
| 75 |
+
rand_mask = get_mask_by_input_strokes(
|
| 76 |
+
init_points=points,
|
| 77 |
+
imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
|
| 78 |
+
rand_mask = (~torch.from_numpy(rand_mask)) * mask
|
| 79 |
+
return rand_mask
|
| 80 |
+
|
| 81 |
+
def draw_eval(self, mask=None, box=None):
|
| 82 |
+
stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
|
| 83 |
+
preset = Circle.get_stroke_preset(stroke_preset_name)
|
| 84 |
+
nStroke = min(self.max_eval, mask.sum().item())
|
| 85 |
+
h,w = mask.shape
|
| 86 |
+
points = self.get_random_points_from_mask(mask, n=nStroke)
|
| 87 |
+
rand_masks = []
|
| 88 |
+
for i in range(len(points)):
|
| 89 |
+
rand_mask = get_mask_by_input_strokes(
|
| 90 |
+
init_points=points[:i+1],
|
| 91 |
+
imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset)
|
| 92 |
+
rand_masks += [(~torch.from_numpy(rand_mask)) * mask]
|
| 93 |
+
return torch.stack(rand_masks)
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def draw_by_points(points, mask, h, w):
|
| 97 |
+
stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use
|
| 98 |
+
preset = Circle.get_stroke_preset(stroke_preset_name)
|
| 99 |
+
rand_mask = get_mask_by_input_strokes(
|
| 100 |
+
init_points=points,
|
| 101 |
+
imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
|
| 102 |
+
rand_masks = (~torch.from_numpy(rand_mask)) * mask
|
| 103 |
+
return rand_masks
|
| 104 |
+
|
| 105 |
+
def __repr__(self,):
|
| 106 |
+
return 'circle'
|
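Note: a toy sketch of exercising the Circle sampler above (the config keys are exactly the ones Circle.__init__ reads; the values are illustrative rather than project defaults, and the import path is assumed from the repository layout):

import torch
from datasets.visual_sampler.circle import Circle

cfg = {'STROKE_SAMPLER': {
    'CIRCLE': {'NUM_STROKES': 5,
               'STROKE_PRESET': ['object_like', 'object_like_middle', 'object_like_small'],
               'STROKE_PROB': [0.33, 0.33, 0.34]},
    'EVAL': {'MAX_ITER': 20}}}
circle = Circle(cfg, is_train=True)

gt = torch.zeros(256, 256, dtype=torch.bool)
gt[64:192, 64:192] = True                     # toy ground-truth region
scribble = circle.draw(mask=gt)               # boolean mask of simulated strokes inside the region
print(scribble.shape, scribble.sum().item())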
datasets/visual_sampler/mask_generators.py
ADDED
|
@@ -0,0 +1,215 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
from PIL import Image, ImageDraw
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_mask_by_input_strokes(
|
| 7 |
+
init_points, imageWidth=320, imageHeight=180, nStroke=5,
|
| 8 |
+
nVertexBound=[10, 30], maxHeadSpeed=15, maxHeadAcceleration=(15, 0.5),
|
| 9 |
+
brushWidthBound=(5, 20), boarderGap=None, nMovePointRatio=0.5, maxPiontMove=10,
|
| 10 |
+
maxLineAcceleration=5, maxInitSpeed=5
|
| 11 |
+
):
|
| 12 |
+
'''
|
| 13 |
+
Get video masks by random strokes which move randomly between each
|
| 14 |
+
frame, including the whole stroke and its control points
|
| 15 |
+
|
| 16 |
+
Parameters
|
| 17 |
+
----------
|
| 18 |
+
imageWidth: Image width
|
| 19 |
+
imageHeight: Image height
|
| 20 |
+
nStroke: Number of drawn lines
|
| 21 |
+
nVertexBound: Lower/upper bound of number of control points for each line
|
| 22 |
+
maxHeadSpeed: Max head speed when creating control points
|
| 23 |
+
maxHeadAcceleration: Max acceleration applying on the current head point (
|
| 24 |
+
a head point and its velocity decide the next point)
|
| 25 |
+
brushWidthBound (min, max): Bound of width for each stroke
|
| 26 |
+
boarderGap: The minimum gap between image boarder and drawed lines
|
| 27 |
+
nMovePointRatio: The ratio of control points to move for next frames
|
| 28 |
+
maxPiontMove: The magnitude of movement for control points for next frames
|
| 29 |
+
maxLineAcceleration: The magnitude of acceleration for the whole line
|
| 30 |
+
|
| 31 |
+
Examples
|
| 32 |
+
----------
|
| 33 |
+
object_like_setting = {
|
| 34 |
+
"nVertexBound": [5, 20],
|
| 35 |
+
"maxHeadSpeed": 15,
|
| 36 |
+
"maxHeadAcceleration": (15, 3.14),
|
| 37 |
+
"brushWidthBound": (30, 50),
|
| 38 |
+
"nMovePointRatio": 0.5,
|
| 39 |
+
"maxPiontMove": 10,
|
| 40 |
+
"maxLineAcceleration": (5, 0.5),
|
| 41 |
+
"boarderGap": 20,
|
| 42 |
+
"maxInitSpeed": 10,
|
| 43 |
+
}
|
| 44 |
+
rand_curve_setting = {
|
| 45 |
+
"nVertexBound": [10, 30],
|
| 46 |
+
"maxHeadSpeed": 20,
|
| 47 |
+
"maxHeadAcceleration": (15, 0.5),
|
| 48 |
+
"brushWidthBound": (3, 10),
|
| 49 |
+
"nMovePointRatio": 0.5,
|
| 50 |
+
"maxPiontMove": 3,
|
| 51 |
+
"maxLineAcceleration": (5, 0.5),
|
| 52 |
+
"boarderGap": 20,
|
| 53 |
+
"maxInitSpeed": 6
|
| 54 |
+
}
|
| 55 |
+
get_video_masks_by_moving_random_stroke(video_len=5, nStroke=3, **object_like_setting)
|
| 56 |
+
'''
|
| 57 |
+
# Initialize a set of control points to draw the first mask
|
| 58 |
+
mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
|
| 59 |
+
control_points_set = []
|
| 60 |
+
for i in range(nStroke):
|
| 61 |
+
brushWidth = np.random.randint(brushWidthBound[0], brushWidthBound[1])
|
| 62 |
+
Xs, Ys, velocity = get_random_stroke_control_points(
|
| 63 |
+
init_point=init_points[i],
|
| 64 |
+
imageWidth=imageWidth, imageHeight=imageHeight,
|
| 65 |
+
nVertexBound=nVertexBound, maxHeadSpeed=maxHeadSpeed,
|
| 66 |
+
maxHeadAcceleration=maxHeadAcceleration, boarderGap=boarderGap,
|
| 67 |
+
maxInitSpeed=maxInitSpeed
|
| 68 |
+
)
|
| 69 |
+
control_points_set.append((Xs, Ys, velocity, brushWidth))
|
| 70 |
+
draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
|
| 71 |
+
|
| 72 |
+
# Generate the following masks by randomly move strokes and their control points
|
| 73 |
+
mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
|
| 74 |
+
for j in range(len(control_points_set)):
|
| 75 |
+
Xs, Ys, velocity, brushWidth = control_points_set[j]
|
| 76 |
+
new_Xs, new_Ys = random_move_control_points(
|
| 77 |
+
Xs, Ys, velocity, nMovePointRatio, maxPiontMove,
|
| 78 |
+
maxLineAcceleration, boarderGap
|
| 79 |
+
)
|
| 80 |
+
control_points_set[j] = (new_Xs, new_Ys, velocity, brushWidth)
|
| 81 |
+
for Xs, Ys, velocity, brushWidth in control_points_set:
|
| 82 |
+
draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
|
| 83 |
+
|
| 84 |
+
return np.array(mask)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
|
| 88 |
+
speed, angle = velocity
|
| 89 |
+
d_speed, d_angle = maxAcceleration
|
| 90 |
+
|
| 91 |
+
if dist == 'uniform':
|
| 92 |
+
speed += np.random.uniform(-d_speed, d_speed)
|
| 93 |
+
angle += np.random.uniform(-d_angle, d_angle)
|
| 94 |
+
elif dist == 'guassian':
|
| 95 |
+
speed += np.random.normal(0, d_speed / 2)
|
| 96 |
+
angle += np.random.normal(0, d_angle / 2)
|
| 97 |
+
else:
|
| 98 |
+
raise NotImplementedError(f'Distribution type {dist} is not supported.')
|
| 99 |
+
|
| 100 |
+
return (speed, angle)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def random_move_control_points(Xs, Ys, lineVelocity, nMovePointRatio, maxPiontMove, maxLineAcceleration, boarderGap=15):
|
| 104 |
+
new_Xs = Xs.copy()
|
| 105 |
+
new_Ys = Ys.copy()
|
| 106 |
+
|
| 107 |
+
# move the whole line and accelerate
|
| 108 |
+
speed, angle = lineVelocity
|
| 109 |
+
new_Xs += int(speed * np.cos(angle))
|
| 110 |
+
new_Ys += int(speed * np.sin(angle))
|
| 111 |
+
lineVelocity = random_accelerate(lineVelocity, maxLineAcceleration, dist='guassian')
|
| 112 |
+
|
| 113 |
+
# choose points to move
|
| 114 |
+
chosen = np.arange(len(Xs))
|
| 115 |
+
np.random.shuffle(chosen)
|
| 116 |
+
chosen = chosen[:int(len(Xs) * nMovePointRatio)]
|
| 117 |
+
for i in chosen:
|
| 118 |
+
new_Xs[i] += np.random.randint(-maxPiontMove, maxPiontMove)
|
| 119 |
+
new_Ys[i] += np.random.randint(-maxPiontMove, maxPiontMove)
|
| 120 |
+
return new_Xs, new_Ys
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_random_stroke_control_points(
|
| 124 |
+
init_point,
|
| 125 |
+
imageWidth, imageHeight,
|
| 126 |
+
nVertexBound=(10, 30), maxHeadSpeed=10, maxHeadAcceleration=(5, 0.5), boarderGap=20,
|
| 127 |
+
maxInitSpeed=10
|
| 128 |
+
):
|
| 129 |
+
'''
|
| 130 |
+
Implementation the free-form training masks generating algorithm
|
| 131 |
+
proposed by JIAHUI YU et al. in "Free-Form Image Inpainting with Gated Convolution"
|
| 132 |
+
'''
|
| 133 |
+
startX = init_point[0]
|
| 134 |
+
startY = init_point[1]
|
| 135 |
+
|
| 136 |
+
Xs = [init_point[0]]
|
| 137 |
+
Ys = [init_point[1]]
|
| 138 |
+
|
| 139 |
+
numVertex = np.random.randint(nVertexBound[0], nVertexBound[1])
|
| 140 |
+
|
| 141 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
| 142 |
+
speed = np.random.uniform(0, maxHeadSpeed)
|
| 143 |
+
|
| 144 |
+
for i in range(numVertex):
|
| 145 |
+
speed, angle = random_accelerate((speed, angle), maxHeadAcceleration)
|
| 146 |
+
speed = np.clip(speed, 0, maxHeadSpeed)
|
| 147 |
+
|
| 148 |
+
nextX = startX + speed * np.sin(angle)
|
| 149 |
+
nextY = startY + speed * np.cos(angle)
|
| 150 |
+
|
| 151 |
+
if boarderGap is not None:
|
| 152 |
+
nextX = np.clip(nextX, boarderGap, imageWidth - boarderGap)
|
| 153 |
+
nextY = np.clip(nextY, boarderGap, imageHeight - boarderGap)
|
| 154 |
+
|
| 155 |
+
startX, startY = nextX, nextY
|
| 156 |
+
Xs.append(nextX)
|
| 157 |
+
Ys.append(nextY)
|
| 158 |
+
|
| 159 |
+
velocity = get_random_velocity(maxInitSpeed, dist='guassian')
|
| 160 |
+
|
| 161 |
+
return np.array(Xs), np.array(Ys), velocity
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def get_random_velocity(max_speed, dist='uniform'):
|
| 165 |
+
if dist == 'uniform':
|
| 166 |
+
speed = np.random.uniform(0, max_speed)
|
| 167 |
+
elif dist == 'guassian':
|
| 168 |
+
speed = np.abs(np.random.normal(0, max_speed / 2))
|
| 169 |
+
else:
|
| 170 |
+
raise NotImplementedError(f'Distribution type {dist} is not supported.')
|
| 171 |
+
|
| 172 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
| 173 |
+
return (speed, angle)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=255):
|
| 177 |
+
radius = brushWidth // 2 - 1
|
| 178 |
+
for i in range(1, len(Xs)):
|
| 179 |
+
draw = ImageDraw.Draw(mask)
|
| 180 |
+
startX, startY = Xs[i - 1], Ys[i - 1]
|
| 181 |
+
nextX, nextY = Xs[i], Ys[i]
|
| 182 |
+
draw.line((startX, startY) + (nextX, nextY), fill=fill, width=brushWidth)
|
| 183 |
+
for x, y in zip(Xs, Ys):
|
| 184 |
+
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill)
|
| 185 |
+
return mask
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# modified from https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/generate_data.py
|
| 189 |
+
def get_random_walk_mask(imageWidth=320, imageHeight=180, length=None):
|
| 190 |
+
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
|
| 191 |
+
canvas = np.zeros((imageHeight, imageWidth)).astype("i")
|
| 192 |
+
if length is None:
|
| 193 |
+
length = imageWidth * imageHeight
|
| 194 |
+
x = random.randint(0, imageHeight - 1)
|
| 195 |
+
y = random.randint(0, imageWidth - 1)
|
| 196 |
+
x_list = []
|
| 197 |
+
y_list = []
|
| 198 |
+
for i in range(length):
|
| 199 |
+
r = random.randint(0, len(action_list) - 1)
|
| 200 |
+
x = np.clip(x + action_list[r][0], a_min=0, a_max=imageHeight - 1)
|
| 201 |
+
y = np.clip(y + action_list[r][1], a_min=0, a_max=imageWidth - 1)
|
| 202 |
+
x_list.append(x)
|
| 203 |
+
y_list.append(y)
|
| 204 |
+
canvas[np.array(x_list), np.array(y_list)] = 1
|
| 205 |
+
return Image.fromarray(canvas * 255).convert('1')
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_masked_ratio(mask):
|
| 209 |
+
"""
|
| 210 |
+
Calculate the masked ratio.
|
| 211 |
+
mask: Expected a binary PIL image, where 0 and 1 represent
|
| 212 |
+
masked(invalid) and valid pixel values.
|
| 213 |
+
"""
|
| 214 |
+
hist = mask.histogram()
|
| 215 |
+
return hist[0] / np.prod(mask.size)
|
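Note: following the object_like_setting listed in the docstring of get_mask_by_input_strokes, a standalone call could look like this sketch (the import path is assumed from the repository layout):

import numpy as np
from datasets.visual_sampler.mask_generators import get_mask_by_input_strokes

# one (x, y) starting point per stroke on a 320x180 canvas
init_points = np.stack([np.random.randint(0, 320, size=3),
                        np.random.randint(0, 180, size=3)], axis=1).astype(float)
stroke_mask = get_mask_by_input_strokes(
    init_points, imageWidth=320, imageHeight=180, nStroke=3,
    nVertexBound=[5, 20], maxHeadSpeed=15, maxHeadAcceleration=(15, 3.14),
    brushWidthBound=(30, 50), nMovePointRatio=0.5, maxPiontMove=10,
    maxLineAcceleration=(5, 0.5), boarderGap=20, maxInitSpeed=10)
print(stroke_mask.shape, stroke_mask.dtype)   # (180, 320) bool array, False where strokes were drawn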
datasets/visual_sampler/point.py
ADDED
|
@@ -0,0 +1,74 @@
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy import ndimage
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Point:
|
| 9 |
+
def __init__(self, cfg, is_train=True):
|
| 10 |
+
self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
|
| 11 |
+
self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
| 12 |
+
self.is_train = is_train
|
| 13 |
+
|
| 14 |
+
def draw(self, mask=None, box=None):
|
| 15 |
+
if mask.sum() < 10:
|
| 16 |
+
return torch.zeros(mask.shape).bool() # if mask is empty
|
| 17 |
+
if not self.is_train:
|
| 18 |
+
return self.draw_eval(mask=mask, box=box)
|
| 19 |
+
max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number
|
| 20 |
+
num_points = random.randint(1, max_points) # get a random number of points
|
| 21 |
+
h,w = mask.shape
|
| 22 |
+
view_mask = mask.view(-1)
|
| 23 |
+
non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask
|
| 24 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id
|
| 25 |
+
non_zero_idx = non_zero_idx[selected_idx] # select non-zero index
|
| 26 |
+
rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask
|
| 27 |
+
rand_mask[non_zero_idx] = True # set the sampled non-zero locations to True
|
| 28 |
+
# dilate
|
| 29 |
+
# struct = ndimage.generate_binary_structure(2, 2)
|
| 30 |
+
# rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
|
| 31 |
+
# return rand_mask
|
| 32 |
+
return rand_mask.reshape(h, w)
|
| 33 |
+
|
| 34 |
+
def draw_eval(self, mask=None, box=None):
|
| 35 |
+
background = ~mask
|
| 36 |
+
neg_num = min(self.max_eval // 2, background.sum().item())
|
| 37 |
+
pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1
|
| 38 |
+
|
| 39 |
+
h,w = mask.shape
|
| 40 |
+
view_mask = mask.view(-1)
|
| 41 |
+
non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask
|
| 42 |
+
selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id
|
| 43 |
+
non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index
|
| 44 |
+
pos_idx = torch.ones(non_zero_idx_pos.shape)
|
| 45 |
+
|
| 46 |
+
view_background = background.view(-1)
|
| 47 |
+
non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask
|
| 48 |
+
selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id
|
| 49 |
+
non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index
|
| 50 |
+
neg_idx = torch.ones(non_zero_idx_neg.shape) * -1
|
| 51 |
+
|
| 52 |
+
non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg])
|
| 53 |
+
idx = torch.cat([pos_idx, neg_idx])
|
| 54 |
+
rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long()
|
| 55 |
+
non_zero_idx = non_zero_idx[rand_idx]
|
| 56 |
+
idx = idx[rand_idx]
|
| 57 |
+
|
| 58 |
+
rand_masks = []
|
| 59 |
+
for i in range(0, len(non_zero_idx)):
|
| 60 |
+
rand_mask = torch.zeros(view_mask.shape) # init rand mask
|
| 61 |
+
rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero
|
| 62 |
+
# struct = ndimage.generate_binary_structure(2, 2)
|
| 63 |
+
# rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
|
| 64 |
+
rand_masks += [rand_mask.reshape(h, w)]
|
| 65 |
+
|
| 66 |
+
# kernel_size = 3
|
| 67 |
+
rand_masks = torch.stack(rand_masks)
|
| 68 |
+
# rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0]
|
| 69 |
+
# rand_masks[rand_masks>0] = 1
|
| 70 |
+
# rand_masks[rand_masks<0] = -1
|
| 71 |
+
return rand_masks
|
| 72 |
+
|
| 73 |
+
def __repr__(self,):
|
| 74 |
+
return 'point'
|
datasets/visual_sampler/polygon.py
ADDED
@@ -0,0 +1,137 @@
import random

import numpy as np
import torch
from scipy.special import binom
from scipy import ndimage
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg

bernstein = lambda n, k, t: binom(n, k) * t**k * (1. - t)**(n - k)

def bezier(points, num=200):
    N = len(points)
    t = np.linspace(0, 1, num=num)
    curve = np.zeros((num, 2))
    for i in range(N):
        curve += np.outer(bernstein(N - 1, i, t), points[i])
    return curve

class Segment():
    def __init__(self, p1, p2, angle1, angle2, **kw):
        self.p1 = p1; self.p2 = p2
        self.angle1 = angle1; self.angle2 = angle2
        self.numpoints = kw.get("numpoints", 100)
        r = kw.get("r", 0.3)
        d = np.sqrt(np.sum((self.p2 - self.p1)**2))
        self.r = r * d
        self.p = np.zeros((4, 2))
        self.p[0, :] = self.p1[:]
        self.p[3, :] = self.p2[:]
        self.calc_intermediate_points(self.r)

    def calc_intermediate_points(self, r):
        self.p[1, :] = self.p1 + np.array([self.r * np.cos(self.angle1),
                                           self.r * np.sin(self.angle1)])
        self.p[2, :] = self.p2 + np.array([self.r * np.cos(self.angle2 + np.pi),
                                           self.r * np.sin(self.angle2 + np.pi)])
        self.curve = bezier(self.p, self.numpoints)

def get_curve(points, **kw):
    segments = []
    for i in range(len(points) - 1):
        seg = Segment(points[i, :2], points[i+1, :2], points[i, 2], points[i+1, 2], **kw)
        segments.append(seg)
    curve = np.concatenate([s.curve for s in segments])
    return segments, curve

def ccw_sort(p):
    d = p - np.mean(p, axis=0)
    s = np.arctan2(d[:, 0], d[:, 1])
    return p[np.argsort(s), :]

def get_bezier_curve(a, rad=0.2, edgy=0):
    """ given an array of points *a*, create a curve through
    those points.
    *rad* is a number between 0 and 1 to steer the distance of
    control points.
    *edgy* is a parameter which controls how "edgy" the curve is,
    edgy=0 is smoothest."""
    p = np.arctan(edgy) / np.pi + .5
    a = ccw_sort(a)
    a = np.append(a, np.atleast_2d(a[0, :]), axis=0)
    d = np.diff(a, axis=0)
    ang = np.arctan2(d[:, 1], d[:, 0])
    f = lambda ang: (ang >= 0) * ang + (ang < 0) * (ang + 2 * np.pi)
    ang = f(ang)
    ang1 = ang
    ang2 = np.roll(ang, 1)
    ang = p * ang1 + (1 - p) * ang2 + (np.abs(ang2 - ang1) > np.pi) * np.pi
    ang = np.append(ang, [ang[0]])
    a = np.append(a, np.atleast_2d(ang).T, axis=1)
    s, c = get_curve(a, r=rad, method="var")
    x, y = c.T
    return x, y, a

class Polygon:
    def __init__(self, cfg, is_train):
        self.max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
        self.eval_points = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
        self.is_train = is_train

    def get_random_points_from_mask(self, mask, n=3):
        h, w = mask.shape
        view_mask = mask.reshape(h * w)
        non_zero_idx = view_mask.nonzero()[:, 0]
        selected_idx = torch.randperm(len(non_zero_idx))[:n]
        non_zero_idx = non_zero_idx[selected_idx]
        y = (non_zero_idx // w) * 1.0 / (h + 1)
        x = (non_zero_idx % w) * 1.0 / (w + 1)
        return torch.cat((x[:, None], y[:, None]), dim=1).numpy()

    def draw(self, mask=None, box=None):
        if mask.sum() < 10:
            return torch.zeros(mask.shape).bool()  # if mask is empty
        if not self.is_train:
            return self.draw_eval(mask=mask, box=box)
        # box: x1,y1,x2,y2
        x1, y1, x2, y2 = box.int().unbind()
        rad = 0.2
        edgy = 0.05
        num_points = random.randint(1, min(self.max_points, mask.sum().item()))
        a = self.get_random_points_from_mask(mask[y1:y2, x1:x2], n=num_points)
        x, y, _ = get_bezier_curve(a, rad=rad, edgy=edgy)
        x = x.clip(0.0, 1.0)
        y = y.clip(0.0, 1.0)
        points = torch.from_numpy(np.concatenate((y[None,] * (y2 - y1 - 1).item(), x[None,] * (x2 - x1 - 1).item()))).int()
        canvas = torch.zeros((y2 - y1, x2 - x1))
        canvas[points.long().tolist()] = 1
        rand_mask = torch.zeros(mask.shape)
        rand_mask[y1:y2, x1:x2] = canvas
        return rand_mask.bool()

    def draw_eval(self, mask=None, box=None):
        # box: x1,y1,x2,y2
        x1, y1, x2, y2 = box.int().unbind()
        rad = 0.2
        edgy = 0.05
        num_points = min(self.eval_points, mask.sum().item())
        a = self.get_random_points_from_mask(mask[y1:y2, x1:x2], n=num_points)
        rand_masks = []
        for i in range(len(a)):
            x, y, _ = get_bezier_curve(a[:i+1], rad=rad, edgy=edgy)
            x = x.clip(0.0, 1.0)
            y = y.clip(0.0, 1.0)
            points = torch.from_numpy(np.concatenate((y[None,] * (y2 - y1 - 1).item(), x[None,] * (x2 - x1 - 1).item()))).int()
            canvas = torch.zeros((y2 - y1, x2 - x1))
            canvas[points.long().tolist()] = 1
            rand_mask = torch.zeros(mask.shape)
            rand_mask[y1:y2, x1:x2] = canvas

            struct = ndimage.generate_binary_structure(2, 2)
            rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask, structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
            rand_masks += [rand_mask.bool()]
        return torch.stack(rand_masks)

    def __repr__(self):
        return 'polygon'
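An illustrative sketch of calling get_bezier_curve from the file above (the control points are arbitrary; it assumes the repo is importable):

import numpy as np
from datasets.visual_sampler.polygon import get_bezier_curve

a = np.array([[0.1, 0.2], [0.8, 0.3], [0.5, 0.9]])  # three (x, y) points in [0, 1]
x, y, _ = get_bezier_curve(a, rad=0.2, edgy=0.05)   # closed Bezier curve through the points
x, y = x.clip(0.0, 1.0), y.clip(0.0, 1.0)           # same clipping Polygon.draw applies
print(x.shape, y.shape)                             # 100 samples per curve segment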
datasets/visual_sampler/sampler.py
ADDED
@@ -0,0 +1,77 @@
import sys
import random

import torch
import torch.nn as nn

from .point import Point
from .polygon import Polygon
from .scribble import Scribble
from .circle import Circle

from modeling.utils import configurable


class ShapeSampler(nn.Module):
    @configurable
    def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
        super().__init__()
        self.max_candidate = max_candidate
        self.shape_prob = shape_prob
        self.shape_candidate = shape_candidate
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
        candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
        candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']

        if mode == 'hack_train':
            candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]
        else:
            # overwrite candidate_probs
            if not is_train:
                candidate_probs = [0.0 for x in range(len(candidate_names))]
                candidate_probs[candidate_names.index(mode)] = 1.0
            candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]

        # Build augmentation
        return {
            "max_candidate": max_candidate,
            "shape_prob": candidate_probs,
            "shape_candidate": candidate_classes,
            "is_train": is_train,
        }

    def forward(self, instances):
        masks = instances.gt_masks.tensor
        boxes = instances.gt_boxes.tensor

        if len(masks) == 0:
            gt_masks = torch.zeros(masks.shape[-2:]).bool()
            rand_masks = torch.zeros(masks.shape[-2:]).bool()
            return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
        indices = [x for x in range(len(masks))]

        if self.is_train:
            random.shuffle(indices)
            candidate_mask = masks[indices[:self.max_candidate]]
            candidate_box = boxes[indices[:self.max_candidate]]
        else:
            candidate_mask = masks
            candidate_box = boxes

        draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
        rand_shapes = [d.draw(x, y) for d, x, y in zip(draw_funcs, candidate_mask, candidate_box)]
        types = [repr(x) for x in draw_funcs]
        for i in range(0, len(rand_shapes)):
            if rand_shapes[i].sum() == 0:
                candidate_mask[i] = candidate_mask[i] * 0
                types[i] = 'none'

        # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
        return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}

def build_shape_sampler(cfg, **kwargs):
    return ShapeSampler(cfg, **kwargs)
datasets/visual_sampler/scribble.py
ADDED
@@ -0,0 +1,96 @@
import random

import torch

from .mask_generators import get_mask_by_input_strokes

class Scribble:
    def __init__(self, cfg, is_train):
        self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES']
        self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET']
        self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB']
        self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
        self.is_train = is_train

    @staticmethod
    def get_stroke_preset(stroke_preset):
        if stroke_preset == 'rand_curve':
            return {
                "nVertexBound": [10, 30],
                "maxHeadSpeed": 20,
                "maxHeadAcceleration": (15, 0.5),
                "brushWidthBound": (3, 10),
                "nMovePointRatio": 0.5,
                "maxPiontMove": 3,
                "maxLineAcceleration": (5, 0.5),
                "boarderGap": None,
                "maxInitSpeed": 6
            }
        elif stroke_preset == 'rand_curve_small':
            return {
                "nVertexBound": [6, 22],
                "maxHeadSpeed": 12,
                "maxHeadAcceleration": (8, 0.5),
                "brushWidthBound": (2.5, 5),
                "nMovePointRatio": 0.5,
                "maxPiontMove": 1.5,
                "maxLineAcceleration": (3, 0.5),
                "boarderGap": None,
                "maxInitSpeed": 3
            }
        else:
            raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')

    def get_random_points_from_mask(self, mask, n=5):
        h, w = mask.shape
        view_mask = mask.reshape(h * w)
        non_zero_idx = view_mask.nonzero()[:, 0]
        selected_idx = torch.randperm(len(non_zero_idx))[:n]
        non_zero_idx = non_zero_idx[selected_idx]
        y = (non_zero_idx // w) * 1.0
        x = (non_zero_idx % w) * 1.0
        return torch.cat((x[:, None], y[:, None]), dim=1).numpy()

    def draw(self, mask=None, box=None):
        if mask.sum() < 10:
            return torch.zeros(mask.shape).bool()  # if mask is empty
        if not self.is_train:
            return self.draw_eval(mask=mask, box=box)
        stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
        preset = Scribble.get_stroke_preset(stroke_preset_name)
        nStroke = random.randint(1, min(self.num_stroke, mask.sum().item()))
        h, w = mask.shape
        points = self.get_random_points_from_mask(mask, n=nStroke)
        rand_mask = get_mask_by_input_strokes(
            init_points=points,
            imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
        rand_mask = (~torch.from_numpy(rand_mask)) * mask
        return rand_mask

    def draw_eval(self, mask=None, box=None):
        stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
        preset = Scribble.get_stroke_preset(stroke_preset_name)
        nStroke = min(self.eval_stroke, mask.sum().item())
        h, w = mask.shape
        points = self.get_random_points_from_mask(mask, n=nStroke)
        rand_masks = []
        for i in range(len(points)):
            rand_mask = get_mask_by_input_strokes(
                init_points=points[:i+1],
                imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset)
            rand_mask = (~torch.from_numpy(rand_mask)) * mask
            rand_masks += [rand_mask]
        return torch.stack(rand_masks)

    @staticmethod
    def draw_by_points(points, mask, h, w):
        stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0]
        preset = Scribble.get_stroke_preset(stroke_preset_name)
        rand_mask = get_mask_by_input_strokes(
            init_points=points,
            imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
        rand_masks = (~torch.from_numpy(rand_mask)) * mask
        return rand_masks

    def __repr__(self):
        return 'scribble'
datasets/visual_sampler/simpleclick_sampler.py
ADDED
@@ -0,0 +1,252 @@
import sys
import random

import cv2
import numpy as np
from scipy import ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.contrib import distance_transform

from .point import Point
from .polygon import Polygon, get_bezier_curve
from .scribble import Scribble
from .circle import Circle

from modeling.utils import configurable


class SimpleClickSampler(nn.Module):
    @configurable
    def __init__(self, mask_mode='point', sample_negtive=False, is_train=True, dilation=None, dilation_kernel=None, max_points=None):
        super().__init__()
        self.mask_mode = mask_mode
        self.sample_negtive = sample_negtive
        self.is_train = is_train
        self.dilation = dilation
        self.register_buffer("dilation_kernel", dilation_kernel)
        self.max_points = max_points

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        mask_mode = mode
        sample_negtive = cfg['STROKE_SAMPLER']['EVAL']['NEGATIVE']

        dilation = cfg['STROKE_SAMPLER']['DILATION']
        dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())

        max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']

        # Build augmentation
        return {
            "mask_mode": mask_mode,
            "sample_negtive": sample_negtive,
            "is_train": is_train,
            "dilation": dilation,
            "dilation_kernel": dilation_kernel,
            "max_points": max_points,
        }

    def forward_point(self, instances, pred_masks=None, prev_masks=None):
        gt_masks = instances.gt_masks.tensor
        n, h, w = gt_masks.shape

        # We only consider positive points
        pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:, :h, :w]
        prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks

        if not gt_masks.is_cuda:
            gt_masks = gt_masks.to(pred_masks.device)

        fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)

        # conv implementation
        mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0, :, 1:-1, 1:-1]).reshape(n, -1)
        max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
        next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
        next_mask = next_mask.view(n, -1)

        next_mask[max_xy_idx] = True
        next_mask = next_mask.reshape((n, h, w)).float()
        next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask), 1, 1, 1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
        # end conv implementation

        # disk implementation
        # mask_dt = distance_transform((~fp)[None,].float())[0].view(n,-1)
        # max_xy = mask_dt.max(dim=-1)[1]
        # max_y, max_x = max_xy//w, max_xy%w
        # max_xy_idx = torch.stack([max_y, max_x]).transpose(0,1)[:,:,None,None]
        # y_idx = torch.arange(start=0, end=h, step=1, dtype=torch.float32, device=torch.cuda.current_device())
        # x_idx = torch.arange(start=0, end=w, step=1, dtype=torch.float32, device=torch.cuda.current_device())
        # coord_y, coord_x = torch.meshgrid(y_idx, x_idx)
        # coords = torch.stack((coord_y, coord_x), dim=0).unsqueeze(0).repeat(len(max_xy_idx),1,1,1) # [bsx2,2,h,w], corresponding to 2d coordinate
        # coords.add_(-max_xy_idx)
        # coords.mul_(coords)
        # next_mask = coords[:, 0] + coords[:, 1]
        # next_mask = (next_mask <= 5**2)
        # end disk implementation

        rand_shapes = prev_masks | next_mask

        types = ['point' for i in range(len(gt_masks))]
        return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:, None], 'types': types}

    def forward_circle(self, instances, pred_masks=None, prev_masks=None):
        gt_masks = instances.gt_masks.tensor
        n, h, w = gt_masks.shape

        # We only consider positive points
        pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:, :h, :w]
        prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks

        if not gt_masks.is_cuda:
            gt_masks = gt_masks.to(pred_masks.device)

        fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)

        # conv implementation
        mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0, :, 1:-1, 1:-1]).reshape(n, -1)
        max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
        next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
        next_mask = next_mask.view(n, -1)

        next_mask[max_xy_idx] = True
        next_mask = next_mask.reshape((n, h, w)).float()

        _next_mask = []
        for idx in range(len(next_mask)):
            points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
            _next_mask += [Circle.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
        next_mask = torch.cat(_next_mask, dim=0).bool().cuda()
        rand_shapes = prev_masks | next_mask

        types = ['circle' for i in range(len(gt_masks))]
        return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:, None], 'types': types}

    def forward_scribble(self, instances, pred_masks=None, prev_masks=None):
        gt_masks = instances.gt_masks.tensor
        n, h, w = gt_masks.shape

        # We only consider positive points
        pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:, :h, :w]
        prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks

        if not gt_masks.is_cuda:
            gt_masks = gt_masks.to(pred_masks.device)

        fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)

        # conv implementation
        mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0, :, 1:-1, 1:-1]).reshape(n, -1)
        max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
        next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
        next_mask = next_mask.view(n, -1)

        next_mask[max_xy_idx] = True
        next_mask = next_mask.reshape((n, h, w)).float()

        _next_mask = []
        for idx in range(len(next_mask)):
            points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
            _next_mask += [Scribble.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
        next_mask = torch.cat(_next_mask, dim=0).bool().cuda()
        rand_shapes = prev_masks | next_mask

        types = ['scribble' for i in range(len(gt_masks))]
        return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:, None], 'types': types}

    def forward_polygon(self, instances, pred_masks=None, prev_masks=None):
        gt_masks = instances.gt_masks.tensor
        gt_boxes = instances.gt_boxes.tensor
        n, h, w = gt_masks.shape

        # We only consider positive points
        pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:, :h, :w]
        prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks

        if not gt_masks.is_cuda:
            gt_masks = gt_masks.to(pred_masks.device)

        fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)

        next_mask = []
        for i in range(len(fp)):
            rad = 0.2
            edgy = 0.05
            num_points = random.randint(1, min(self.max_points, fp[i].sum()))

            h, w = fp[i].shape
            view_mask = fp[i].reshape(h * w)
            non_zero_idx = view_mask.nonzero()[:, 0]
            selected_idx = torch.randperm(len(non_zero_idx))[:num_points]
            non_zero_idx = non_zero_idx[selected_idx]
            y = (non_zero_idx // w) * 1.0 / (h + 1)
            x = (non_zero_idx % w) * 1.0 / (w + 1)
            coords = torch.cat((x[:, None], y[:, None]), dim=1).cpu().numpy()

            x1, y1, x2, y2 = gt_boxes[i].int().unbind()
            x, y, _ = get_bezier_curve(coords, rad=rad, edgy=edgy)
            x = x.clip(0.0, 1.0)
            y = y.clip(0.0, 1.0)
            points = torch.from_numpy(np.concatenate((y[None,] * (y2 - y1 - 1).item(), x[None,] * (x2 - x1 - 1).item()))).int()
            canvas = torch.zeros((y2 - y1, x2 - x1))
            canvas[points.long().tolist()] = 1
            rand_mask = torch.zeros(fp[i].shape)
            rand_mask[y1:y2, x1:x2] = canvas
            next_mask += [rand_mask]

        next_mask = torch.stack(next_mask).to(pred_masks.device).bool()
        rand_shapes = prev_masks | next_mask

        types = ['polygon' for i in range(len(gt_masks))]
        return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:, None], 'types': types}

    def forward_box(self, instances, pred_masks=None, prev_masks=None):
        gt_masks = instances.gt_masks.tensor
        gt_boxes = instances.gt_boxes.tensor
        n, h, w = gt_masks.shape

        for i in range(len(gt_masks)):
            x1, y1, x2, y2 = gt_boxes[i].int().unbind()
            gt_masks[i, y1:y2, x1:x2] = 1

        # We only consider positive points
        pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:, :h, :w]
        prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks

        if not gt_masks.is_cuda:
            gt_masks = gt_masks.to(pred_masks.device)

        fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)

        # conv implementation
        mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0, :, 1:-1, 1:-1]).reshape(n, -1)
        max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
        next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
        next_mask = next_mask.view(n, -1)

        next_mask[max_xy_idx] = True
        next_mask = next_mask.reshape((n, h, w)).float()
        next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask), 1, 1, 1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
        # end conv implementation

        rand_shapes = prev_masks | next_mask

        types = ['box' for i in range(len(gt_masks))]
        return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:, None], 'types': types}

    def forward(self, instances, *args, **kwargs):
        if self.mask_mode == 'Point':
            return self.forward_point(instances, *args, **kwargs)
        elif self.mask_mode == 'Circle':
            return self.forward_circle(instances, *args, **kwargs)
        elif self.mask_mode == 'Scribble':
            return self.forward_scribble(instances, *args, **kwargs)
        elif self.mask_mode == 'Polygon':
            return self.forward_polygon(instances, *args, **kwargs)
        elif self.mask_mode == 'Box':
            return self.forward_box(instances, *args, **kwargs)

def build_shape_sampler(cfg, **kwargs):
    return ShapeSampler(cfg, **kwargs)
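For reference, the configuration keys read by the samplers above can be collected into a single STROKE_SAMPLER block. The sketch below is hypothetical: the key names come from the code in this diff, but the values are illustrative and are not the project's actual defaults (the Circle sampler's own keys are not shown in this diff and are omitted).

stroke_sampler_cfg = {
    'STROKE_SAMPLER': {
        'MAX_CANDIDATE': 10,                      # ShapeSampler: max masks sampled per image (illustrative)
        'CANDIDATE_NAMES': ['Point', 'Polygon', 'Scribble', 'Circle'],
        'CANDIDATE_PROBS': [0.25, 0.25, 0.25, 0.25],
        'POINT': {'NUM_POINTS': 20},              # Point: max positive clicks per mask (illustrative)
        'POLYGON': {'MAX_POINTS': 8},             # Polygon / SimpleClickSampler (illustrative)
        'SCRIBBLE': {
            'NUM_STROKES': 5,
            'STROKE_PRESET': ['rand_curve', 'rand_curve_small'],
            'STROKE_PROB': [0.5, 0.5],
        },
        'EVAL': {'MAX_ITER': 20, 'NEGATIVE': False},
        'DILATION': 3,                            # SimpleClickSampler: click dilation kernel size (illustrative)
    }
}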
docker/Dockerfile
ADDED
@@ -0,0 +1,32 @@
# FROM naotous/flash_attn:2.0.5-pytorch23.07
FROM wangkenpu/pytorch:1.8.0-py39-cuda11.1-cudnn8-ubuntu18.04

# RUN touch tensorboard_patcher.py && cp tensorboard_patcher.py $$USERSITE/usercustomize.py


# RUN pip install --upgrade pip

# RUN pip install -I torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# RUN pip install -I torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --user
# RUN pip install kornia
# RUN pip install timm==0.4.12
# RUN python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
RUN pip install git+https://github.com/cocodataset/panopticapi.git
RUN pip install git+https://github.com/openai/CLIP.git

# RUN wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

COPY assets/requirements/requirements.txt /tmp/requirements.txt
RUN pip install -r /tmp/requirements.txt

COPY assets/requirements/requirements_custom.txt /tmp/requirements_custom.txt
RUN pip install -r /tmp/requirements_custom.txt

#RUN pip install -U protobuf

# Set environment variables
ENV MKL_THREADING_LAYER=GNU
ENV NCCL_DEBUG=INFO

# Set the working directory HERE!
WORKDIR /path/to/BiomedParse
docker/README.md
ADDED
@@ -0,0 +1,9 @@
In the Dockerfile, set WORKDIR to the path of your BiomedParse repo.

From the project root dir, run:

bash docker/docker_build.sh

Then start the container with bash docker/docker_run.sh

Inside the docker container, run setup_inside_docker.sh
docker/data_env.sh
ADDED
@@ -0,0 +1 @@
export HANOVER_DATASETS=biomedparse_datasets/ # Path to the datasets
docker/docker_build.sh
ADDED
@@ -0,0 +1 @@
docker build -f docker/Dockerfile -t seem .
docker/docker_run.sh
ADDED
@@ -0,0 +1 @@
docker run -it --gpus all --shm-size=128G -v /mnt:/mnt -v $(pwd):/workspace -w /workspace seem
docker/setup_inside_docker.sh
ADDED
@@ -0,0 +1,10 @@
# Custom operator [only needed when training the deformable vision encoder]
cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../

# System package [only needed for the demo in SEEM]
sudo apt update
sudo apt install ffmpeg

#pip install gradio==3.44.4
#pip install openai-whisper
#pip install protobuf==3.20.*
entry.py
ADDED
@@ -0,0 +1,92 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------

import os
import sys
import torch
import logging
#import wandb
import random
import numpy as np

from utilities.arguments import load_opt_command

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# def init_wandb(args, job_dir, entity='YOUR_USER_NAME', project='YOUR_PROJECT_NAME', job_name='tmp'):
#     wandb_dir = os.path.join(job_dir, 'wandb')
#     os.makedirs(wandb_dir, exist_ok=True)
#     runid = None
#     if os.path.exists(f"{wandb_dir}/runid.txt"):
#         runid = open(f"{wandb_dir}/runid.txt").read()

#     wandb.init(project=project,
#                name=job_name,
#                dir=wandb_dir,
#                entity=entity,
#                resume="allow",
#                id=runid,
#                config={"hierarchical": True},)

#     open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id)
#     wandb.config.update({k: args[k] for k in args if k not in wandb.config})

def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def main(args=None):
    '''
    [Main function for the entry point]
    1. Set environment variables for distributed training.
    2. Load the config file and set up the trainer.
    '''

    opt, cmdline_args = load_opt_command(args)
    command = cmdline_args.command

    if cmdline_args.user_dir:
        absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
        opt['base_path'] = absolute_user_dir

    # update_opt(opt, command)
    world_size = 1
    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])

    if opt['TRAINER'] == 'xdecoder':
        from trainer import XDecoder_Trainer as Trainer
    else:
        assert False, "The trainer type: {} is not defined!".format(opt['TRAINER'])

    set_seed(opt['RANDOM_SEED'])

    trainer = Trainer(opt)
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'

    if command == "train":
        # if opt['rank'] == 0 and opt['WANDB']:
        #     wandb.login(key=os.environ['WANDB_KEY'])
        #     init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder)
        trainer.train()
    elif command == "evaluate":
        trainer.eval()
    else:
        raise ValueError(f"Unknown command: {command}")

if __name__ == "__main__":
    main()
    sys.exit(0)
environment.yml
ADDED
@@ -0,0 +1,149 @@
name: biomedparse
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py39h6a678d5_8
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - certifi=2024.7.4=py39h06a4308_0
  - charset-normalizer=3.3.2=pyhd3eb1b0_0
  - cuda-cudart=12.4.127=0
  - cuda-cupti=12.4.127=0
  - cuda-libraries=12.4.0=0
  - cuda-nvrtc=12.4.127=0
  - cuda-nvtx=12.4.127=0
  - cuda-opencl=12.6.37=0
  - cuda-runtime=12.4.0=0
  - cuda-version=12.6=3
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=py39h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py39heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.7=py39h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.4=py39h06a4308_0
  - jpeg=9e=h5eee18b_3
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=12.4.2.65=0
  - libcufft=11.2.0.44=0
  - libcufile=1.11.0.15=0
  - libcurand=10.3.7.37=0
  - libcusolver=11.6.0.99=0
  - libcusparse=12.3.0.142=0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h5eee18b_3
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libnpp=12.2.5.2=0
  - libnvfatbin=12.6.20=0
  - libnvjitlink=12.4.99=0
  - libnvjpeg=12.3.1.89=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libwebp-base=1.3.2=h5eee18b_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lz4-c=1.9.4=h6a678d5_1
  - markupsafe=2.1.3=py39h5eee18b_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py39h5eee18b_1
  - mkl_fft=1.3.8=py39h5eee18b_0
  - mkl_random=1.2.4=py39hdb19cb5_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py39h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.2.1=py39h06a4308_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.5.2=he7f1fd0_0
  - openssl=3.0.14=h5eee18b_0
  - pip=24.2=py39h06a4308_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.19=h955ad1f_1
  - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
  - pytorch-cuda=12.4=hc786d27_6
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py39h5eee18b_0
  - readline=8.2=h5eee18b_0
  - requests=2.32.3=py39h06a4308_0
  - setuptools=72.1.0=py39h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - sympy=1.12=py39h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.14=h39e8969_0
  - torchaudio=2.4.0=py39_cu124
  - torchtriton=3.0.0=py39
  - torchvision=0.19.0=py39_cu124
  - typing_extensions=4.11.0=py39h06a4308_0
  - tzdata=2024a=h04d1e81_0
  - urllib3=2.2.2=py39h06a4308_0
  - wheel=0.43.0=py39h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.13=h5eee18b_1
  - zstd=1.5.5=hc292b87_2
  - pip:
    - accelerate==0.23.0
    - antlr4-python3-runtime==4.9.3
    - appdirs==1.4.4
    - black==21.4b2
    - open-clip-torch==2.26.1
    - cloudpickle==3.0.0
    - cython==3.0.2
    - deepspeed==0.10.3
    - git+https://github.com/MaureenZOU/detectron2-xyz.git
    - diffdist==0.1
    - einops==0.8.0
    - ftfy==6.1.1
    - fvcore==0.1.5.post20221221
    - hjson==3.1.0
    - huggingface-hub==0.17.3
    - hydra-core==1.3.2
    - imageio==2.35.1
    - infinibatch==0.1.1
    - iopath==0.1.9
    - json-tricks==3.17.3
    - kornia==0.7.0
    - mpi4py==3.1.5
    - mup==1.0.0
    - mypy-extensions==1.0.0
    - ninja==1.11.1.1
    - nltk==3.8.1
    - numpy==1.23.1
    - omegaconf==2.3.0
    - opencv-python==4.8.1.78
    - pandas==2.0.3
    - pathspec==0.12.1
    - pillow==9.4.0
    - portalocker==2.10.1
    - py-cpuinfo==9.0.0
    - pycocotools==2.0.7
    - pydantic==1.10.18
    - pydot==3.0.1
    - regex==2023.10.3
    - scikit-image==0.21.0
    - scikit-learn==1.3.1
    - sentencepiece==0.1.99
    - tabulate==0.9.0
    - termcolor==2.4.0
    - timm==0.4.12
    - tokenizers==0.14.1
    - transformers==4.34.0
    - vision-datasets==0.2.2
    - yacs==0.1.8
example_prediction.py
ADDED
@@ -0,0 +1,47 @@
from PIL import Image
import torch
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
import numpy as np

from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats

opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
opt = init_distributed(opt)

# Load model from pretrained weights (either a local checkpoint or the HF Hub model id)
pretrained_pth = 'pretrained/biomed_parse.pt'
pretrained_pth = 'hf_hub:microsoft/BiomedParse'

model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()
with torch.no_grad():
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)


# Load image and run inference
# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
image = image.convert('RGB')
# Text prompts querying objects in the image. Multiple ones can be provided.
prompts = ['neoplastic cells', 'inflammatory cells']

# Load ground truth masks
gt_masks = []
for prompt in prompts:
    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
    gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
    gt_masks.append(gt_mask)

pred_mask = interactive_infer_image(model, image, prompts)

# Compare predictions with the ground truth masks
for i, pred in enumerate(pred_mask):
    gt = gt_masks[i]
    dice = (1*(pred > 0.5) & gt).sum() * 2.0 / (1*(pred > 0.5).sum() + gt.sum())
    print(f'Dice score for {prompts[i]}: {dice:.4f}')
    p_value = check_mask_stats(np.array(image), pred*255, 'Pathology', prompts[i])
    print(f'p-value for {prompts[i]}: {p_value:.4f}')
examples/144DME_as_F.jpeg
ADDED
examples/C3_EndoCV2021_00462.jpg
ADDED
examples/CT_lung_nodule.dcm
ADDED
Binary file (526 kB)
examples/LIDC-IDRI-0140_143_280_CT_lung.png
ADDED