Spaces:
Running
Running
Update predict.py
Browse files- predict.py +2 -2
predict.py
CHANGED
|
@@ -10,13 +10,13 @@ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
|
| 10 |
from detectron2.data import MetadataCatalog
|
| 11 |
from detectron2.utils.visualizer import ColorMode, Visualizer
|
| 12 |
from color_palette import ade_palette
|
| 13 |
-
from transformers import
|
| 14 |
|
| 15 |
def load_model_and_processor(model_ckpt: str):
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 17 |
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
|
| 18 |
model.eval()
|
| 19 |
-
image_preprocessor =
|
| 20 |
return model, image_preprocessor
|
| 21 |
|
| 22 |
def load_default_ckpt(segmentation_task: str):
|
|
|
|
| 10 |
from detectron2.data import MetadataCatalog
|
| 11 |
from detectron2.utils.visualizer import ColorMode, Visualizer
|
| 12 |
from color_palette import ade_palette
|
| 13 |
+
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
|
| 14 |
|
| 15 |
def load_model_and_processor(model_ckpt: str):
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 17 |
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
|
| 18 |
model.eval()
|
| 19 |
+
image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
|
| 20 |
return model, image_preprocessor
|
| 21 |
|
| 22 |
def load_default_ckpt(segmentation_task: str):
|