Cityscapes Segmentation (8 classes) — SERNet (TorchScript)

Modèle de segmentation sémantique pour scènes urbaines (Cityscapes), 8 classes : void, flat, construction, object, nature, sky, human, vehicle.

  • Architecture : SERNet (compatible API DeepLabV3 / torchvision)
  • Format des poids : TorchScriptsernet_model.pt
  • Métriques (val) : mIoU ≈ 73%, wmIoU ≈ 85%
  • Cible : CPU (fonctionne aussi sur GPU si déplacé)

🧩 Classes (IDs)

  • 0: void
  • 1: flat
  • 2: construction
  • 3: object
  • 4: nature
  • 5: sky
  • 6: human
  • 7: vehicle

🛠️ Preprocessing (torchvision)

Le modèle utilise les transforms officiels associés aux poids torchvision :

from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights

weights = DeepLabV3_ResNet101_Weights.DEFAULT
preprocess = weights.transforms()  # PIL -> Tensor + normalisation ImageNet
  • Taille d’entrée : non imposée. Pour la latence CPU, redimensionner avant preprocess, p.ex. (H, W) = (480, 960).
  • Sortie : logits (B, 8, H, W)argmax(dim=1) donne un masque (H, W) d’IDs de classe.

📦 Utilisation — via Hugging Face Hub (TorchScript)

from huggingface_hub import hf_hub_download
from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights
from PIL import Image
import torch, numpy as np

REPO_ID  = "<votre-user>/p9-cityscapes-model"
FILENAME = "sernet_model.pt"

# 1) Charger le modèle TorchScript
path  = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)  # cache auto (~/.cache/huggingface)
model = torch.jit.load(path, map_location="cpu").eval()

# 2) Préprocess (torchvision)
weights    = DeepLabV3_ResNet101_Weights.DEFAULT
preprocess = weights.transforms()

# 3) Préparer l’image
Resample = getattr(Image, "Resampling", Image)  # compat Pillow<9.1
img = Image.open("demo.jpg").convert("RGB")
img = img.resize((960, 480), Resample.BILINEAR)  # optionnel (latence CPU)
x   = preprocess(img).unsqueeze(0)  # [1,C,H,W]

# 4) Prédire
with torch.inference_mode():
    out = model(x)

logits = out["out"] if isinstance(out, dict) else out  # (1,8,H,W)
seg    = torch.argmax(logits, 1).squeeze(0).cpu().numpy().astype("uint8")
print("mask:", seg.shape, "classes:", np.unique(seg))

🎨 Palette (facultatif)

import numpy as np
from matplotlib import colors

PALETTE = ['b','g','r','c','m','y','k','w']  # 0..7

def colorize(seg: np.ndarray) -> np.ndarray:
    h, w = seg.shape
    out = np.zeros((h, w, 3), dtype=np.float32)
    for cid in range(8):
        mask = (seg == cid)
        r, g, b = colors.to_rgb(PALETTE[cid])
        out[mask, 0] = r; out[mask, 1] = g; out[mask, 2] = b
    return (out * 255).astype(np.uint8)

📊 Métriques

  • SERNet (TorchScript) : mIoU ≈ 73%, wmIoU ≈ 85% sur Cityscapes (val).
  • wmIoU = mIoU pondérée (pondérations par classe).

⚠️ Limites & conseils

  • Conçu pour scènes urbaines (Cityscapes) : généralisation limitée hors domaine.
  • CPU : privilégier des entrées ≤ 960×480 pour une latence raisonnable.
  • Pour reproductibilité stricte, pinner la révision HF (tag/commit).

📚 Références

  • Cordts et al., Cityscapes — CVPR 2016
  • Chen et al., DeepLabRethinking Atrous Convolution for Semantic Image Segmentation
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Space using juliengatineau/SERNet_cityscapes_trained 1