Commit 562f83f
Initial Space setup - models download from HF Hub

Files changed:
- .gitattributes +36 -0
- .gitignore +42 -0
- ICARp_Updated.svg +1 -0
- README.md +165 -0
- app.py +575 -0
- packages.txt +3 -0
- requirements.txt +8 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+**/*.pt filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,42 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+*.egg-info/
+dist/
+build/
+
+# Virtual environments
+venv/
+env/
+ENV/
+
+# HuggingFace cache
+hf_cache/
+.cache/
+
+# Gradio
+gradio_cached_examples/
+flagged/
+
+# Model checkpoints (if testing locally)
+*.pt
+*.pth
+*.ckpt
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+
+ft_models/
ICARp_Updated.svg
ADDED
README.md
ADDED
@@ -0,0 +1,165 @@
+# 🫀 Coronary Artery Segmentation with Fine-tuned SAM2
+
+An interactive tool for segmenting coronary arteries in X-ray angiography images using fine-tuned Segment Anything Model 2 (SAM2). This Space provides an easy-to-use interface for medical image segmentation with point-based prompting.
+
+---
+
+## 🎯 How to Use This Space
+
+### Step 1: Upload Your Image
+- Click on the image upload area or drag and drop your coronary X-ray angiography image
+- Supported formats: JPG, PNG, and other common image formats
+- The image will be displayed in the interface for annotation
+
+### Step 2: Select a Model
+Choose from four fine-tuned SAM2 variants optimized for coronary artery segmentation:
+- **SAM2 Hiera Tiny** - Fastest inference, good for quick results
+- **SAM2 Hiera Small** - Balanced speed and accuracy
+- **SAM2 Hiera Base Plus** - Higher accuracy, moderate speed
+- **SAM2 Hiera Large** - Best accuracy, slower inference
+
+### Step 3: Add Point Prompts
+
+#### 🟢 Positive Points (Green)
+- Click on areas you **want to include** in the segmentation
+- Use positive points to mark the coronary arteries you want to segment
+- Multiple positive points help the model understand the full extent of the artery
+- **Tip**: Add points along the length of the artery for better coverage
+
+#### 🔴 Negative Points (Red)
+- Click on areas you **want to exclude** from the segmentation
+- Use negative points to refine the segmentation and remove unwanted regions
+- Helpful for separating overlapping vessels or removing background
+- **Tip**: Add negative points near the artery boundaries to improve precision
+
+#### 🎨 How to Add Points
+1. The first click adds a **positive point** (green)
+2. Hold **Shift** while clicking to add a **negative point** (red)
+3. You can add multiple points of each type
+4. Points are displayed as colored dots on your image
+
+### Step 4: Generate Segmentation
+- Click the **"Segment"** button to process your annotations
+- The model will generate a segmentation mask based on your point prompts
+- The result shows the segmented coronary artery highlighted on the original image
+
+### Step 5: Refine (Optional)
+- If the segmentation isn't perfect, add more points:
+  - Add positive points to include missing artery sections
+  - Add negative points to exclude over-segmented regions
+- Click **"Segment"** again to update the result
+- Repeat until you're satisfied with the segmentation
+
+### Step 6: Clear and Start Over
+- Use the **"Clear"** button to remove all points and start fresh
+- Upload a new image to segment a different case
+
+---
+
+## 💡 Tips for Best Results
+
+1. **Start Simple**: Begin with 1-2 positive points on the main artery structure
+2. **Be Strategic**: Place positive points at key locations (branches, endpoints, curves)
+3. **Refine Gradually**: Add negative points only where the model over-segments
+4. **Model Selection**: Start with Tiny or Small models for faster iteration, then use larger models for final results
+5. **Multiple Vessels**: For multiple arteries, focus on one at a time with clear positive points
+6. **Contrast Matters**: Images with good contrast between vessels and background work best
+
+---
+
+## 🔬 Model Architecture
+
+
+
+The workflow consists of several key stages:
+1. **Image Preprocessing** - Normalization and enhancement for optimal model input
+2. **Point Prompt Encoding** - Your positive/negative clicks are encoded as spatial prompts
+3. **SAM2 Processing** - Fine-tuned encoder-decoder processes the image with prompts
+4. **Mask Generation** - High-quality segmentation mask output
+5. **Visualization** - Segmented artery overlaid on the original image
+
+---
+
+## 📊 Model Variants
+
+| Model | Parameters | Speed | Accuracy | Best For |
+|-------|-----------|-------|----------|----------|
+| Hiera Tiny | ~38M | ⚡⚡⚡ | ⭐⭐⭐ | Quick experiments, real-time feedback |
+| Hiera Small | ~46M | ⚡⚡ | ⭐⭐⭐⭐ | Balanced performance, general use |
+| Hiera Base Plus | ~80M | ⚡ | ⭐⭐⭐⭐⭐ | High-quality results, clinical evaluation |
+| Hiera Large | ~224M | ⚡ | ⭐⭐⭐⭐⭐ | Best possible accuracy, research |
+
+All models have been fine-tuned specifically for coronary artery segmentation on X-ray angiography images.
+
+---
+
+## ⚠️ Disclaimer
+
+This tool is intended for **research and educational purposes only**. It is not approved for clinical diagnosis or treatment decisions. Always consult qualified healthcare professionals for medical image interpretation.
+
+---
+
+## 🚀 Local Installation
+
+If you want to run this application locally, follow these setup instructions:
+
+### 1. Create Conda Environment
+
+```bash
+# Create a new conda environment named sam2_FT_env with Python 3.10
+conda create -n sam2_FT_env python=3.10.0 -y
+
+# Activate the environment
+conda activate sam2_FT_env
+```
+
+### 2. Install SAM2 Library
+
+```bash
+# Clone the official repository
+git clone https://github.com/facebookresearch/segment-anything-2.git
+cd segment-anything-2
+
+# Install the package in editable mode
+pip install -e .
+cd ..
+```
+
+### 3. Install Dependencies
+
+```bash
+pip install uv
+uv pip install gradio opencv-python-headless torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
+```
+
+### 4. Run the Application
+
+```bash
+# Ensure model checkpoints are in the correct directories
+python app.py
+```
+
+Open `http://127.0.0.1:7860` in your browser.
+
+---
+
+## 📚 Citation
+
+If you use this tool in your research, please cite:
+
+```bibtex
+@article{ravi2024sam2,
+  title={SAM 2: Segment Anything in Images and Videos},
+  author={Ravi, Nikhila and others},
+  journal={arXiv preprint arXiv:2408.00714},
+  year={2024}
+}
+```
+
+---
+
+## 📧 Contact & Support
+
+For questions, issues, or feedback, please open an issue on the GitHub repository or contact the developers.
+
+**Built with ❤️ using SAM2 and Gradio**
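The "Model Architecture" and installation sections of the README describe the point-prompted pipeline that app.py (below) implements. As a rough, non-Gradio illustration of the same load-and-predict path, here is a minimal sketch. It assumes a checkpoint in the format app.py expects (a dict holding 'model_cfg' and 'model_state_dict'); the repo id and filename come from app.py's HF_MODEL_CONFIG, the input filename is hypothetical, preprocessing is reduced to a plain resize, and the higher-level `SAM2ImagePredictor.predict` is used instead of app.py's explicit prompt-encoder and mask-decoder calls:

```python
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Fetch one of the fine-tuned checkpoints listed in app.py's HF_MODEL_CONFIG.
ckpt_path = hf_hub_download(repo_id="astroanand/CoronarySAM2",
                            filename="Coronary_Sam2_t.pt",
                            cache_dir="./hf_cache")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint stores the Hydra config name and the fine-tuned weights,
# mirroring app.py's load_model().
checkpoint = torch.load(ckpt_path, map_location=device)
model = build_sam2(checkpoint["model_cfg"], checkpoint_path=None, device=device)
state_dict = {k.removeprefix("module."): v
              for k, v in checkpoint["model_state_dict"].items()}
model.load_state_dict(state_dict)
model.eval()
predictor = SAM2ImagePredictor(model)

# Hypothetical input file; app.py additionally normalizes and applies CLAHE
# before resizing to 1024x1024.
image_bgr = cv2.imread("angiogram.png")
image = cv2.cvtColor(cv2.resize(image_bgr, (1024, 1024)), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# One positive click on the artery (label 1) and one negative click on
# background (label 0), in 1024x1024 pixel coordinates.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[512, 512], [100, 100]]),
    point_labels=np.array([1, 0]),
    multimask_output=True,
)
best_mask = masks[np.argmax(scores)]  # highest-scoring of the candidate masks
cv2.imwrite("predicted_mask.png", (best_mask * 255).astype(np.uint8))
```

app.py works at the lower level so it can pick the highest-scoring multimask output itself and resize that mask back to the original resolution; the sketch above leaves that bookkeeping to `predict`.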
app.py
ADDED
@@ -0,0 +1,575 @@
+import os
+import gradio as gr
+import torch
+import numpy as np
+import cv2
+import time
+import functools
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+# --- Configuration ---
+
+# Hugging Face model repositories and filenames
+HF_MODEL_CONFIG = {
+    "SAM2 Hiera Tiny": {
+        "repo_id": "astroanand/CoronarySAM2",
+        "filename": "Coronary_Sam2_t.pt"
+    },
+    "SAM2 Hiera Small": {
+        "repo_id": "astroanand/CoronarySAM2",
+        "filename": "Coronary_Sam2_s.pt"
+    },
+    "SAM2 Hiera Base Plus": {
+        "repo_id": "astroanand/CoronarySAM2",
+        "filename": "Coronary_Sam2_b+.pt"
+    },
+    "SAM2 Hiera Large": {
+        "repo_id": "astroanand/CoronarySAM2",
+        "filename": "Coronary_Sam2_l.pt"
+    }
+}
+
+# Download and cache models from Hugging Face
+print("Checking and downloading models from Hugging Face...")
+models_available = {}
+for name, config in HF_MODEL_CONFIG.items():
+    try:
+        print(f"Downloading {name} from {config['repo_id']}...")
+        model_path = hf_hub_download(
+            repo_id=config["repo_id"],
+            filename=config["filename"],
+            cache_dir="./hf_cache"
+        )
+        models_available[name] = model_path
+        print(f"✓ {name} downloaded successfully to {model_path}")
+    except Exception as e:
+        print(f"✗ Warning: Failed to download {name} from {config['repo_id']}: {e}")
+        print(f" {name} will not be available in the dropdown.")
+
+if not models_available:
+    print("Error: No valid models could be downloaded. Please check your internet connection and HF repository access.")
+    # exit() # Or handle gracefully
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+print(f"Using device: {DEVICE}")
+
+# Try importing SAM2 modules
+try:
+    from sam2.build_sam import build_sam2
+    from sam2.sam2_image_predictor import SAM2ImagePredictor
+except ImportError:
+    print("Error: SAM2 modules not found. Make sure 'sam2' directory is in your Python path or installed.")
+    exit()
+
+# --- Preprocessing Functions ---
+# ...existing code...
+def normalize_xray_image(image, kernel_size=(51,51), sigma=0):
+    """Normalize X-ray image by applying Gaussian blur and intensity normalization."""
+    if image is None: return None
+    is_color = len(image.shape) == 3
+    if is_color:
+        gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+    else:
+        gray_image = image.copy()
+    gray_image = gray_image.astype(float)
+    blurred = gray_image.copy()
+    for _ in range(5): # Reduced iterations
+        blurred = cv2.GaussianBlur(blurred, kernel_size, sigma)
+    mean_intensity = np.mean(blurred)
+    factor_image = mean_intensity / (blurred + 1e-10)
+    if is_color:
+        normalized_image = image.copy().astype(float)
+        for i in range(3):
+            normalized_image[:,:,i] = normalized_image[:,:,i] * factor_image
+    else:
+        normalized_image = gray_image * factor_image
+    return np.clip(normalized_image, 0, 255).astype(np.uint8)
+
+def apply_clahe(image_uint8):
+    """Apply CLAHE for better vessel contrast. Expects uint8 input."""
+    if image_uint8 is None: return None
+    is_color = len(image_uint8.shape) == 3
+
+    # --- ADJUST CLAHE STRENGTH HERE ---
+    # Lower clipLimit reduces the contrast enhancement effect.
+    # Original was 2.0. Try values like 1.5, 1.0, or even disable by setting it very low.
+    clahe_clip_limit = 2.0
+    clahe_tile_grid_size = (8, 8)
+    print(f" Applying CLAHE with clipLimit={clahe_clip_limit}, tileGridSize={clahe_tile_grid_size}")
+    # ---------------------------------
+
+    clahe = cv2.createCLAHE(clipLimit=clahe_clip_limit, tileGridSize=clahe_tile_grid_size)
+
+    if is_color:
+        lab = cv2.cvtColor(image_uint8, cv2.COLOR_RGB2LAB)
+        l, a, b = cv2.split(lab)
+        l_clahe = clahe.apply(l)
+        lab_clahe = cv2.merge((l_clahe, a, b))
+        clahe_image_uint8 = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
+    else:
+        clahe_image_uint8 = clahe.apply(image_uint8)
+
+    # Return uint8 [0, 255] suitable for predictor.set_image
+    return clahe_image_uint8
+
+
+def preprocess_image_for_sam2(image_rgb_numpy):
+    """Combined preprocessing: normalization + CLAHE for SAM2 input."""
+    if image_rgb_numpy is None:
+        print("Preprocessing: Input image is None.")
+        return None
+
+    start_time = time.time()
+    print("Preprocessing Step 1: Normalizing X-ray image...")
+    if image_rgb_numpy.dtype != np.uint8:
+        image_rgb_numpy = np.clip(image_rgb_numpy, 0, 255).astype(np.uint8)
+    if len(image_rgb_numpy.shape) == 2:
+        image_rgb_numpy = cv2.cvtColor(image_rgb_numpy, cv2.COLOR_GRAY2RGB)
+
+    normalized_uint8 = normalize_xray_image(image_rgb_numpy)
+    if normalized_uint8 is None:
+        print("Preprocessing failed at normalization step.")
+        return None
+    print(f"Normalization done in {time.time() - start_time:.2f}s")
+
+    start_time_clahe = time.time()
+    print("Preprocessing Step 2: Applying CLAHE...")
+    preprocessed_uint8 = apply_clahe(normalized_uint8) # CLAHE applied here
+    if preprocessed_uint8 is None:
+        print("Preprocessing failed at CLAHE step.")
+        return None
+    print(f"CLAHE done in {time.time() - start_time_clahe:.2f}s")
+    print(f"Total preprocessing time: {time.time() - start_time:.2f}s")
+
+    return preprocessed_uint8 # Return the image after all steps
+
+# --- Model Loading ---
+# ...existing code...
+@functools.lru_cache(maxsize=4) # Cache up to 4 models (one for each variant)
+def load_model(model_name):
+    """Loads the specified SAM2 model and creates a predictor."""
+    print(f"\nAttempting to load model: {model_name}")
+    if model_name not in models_available: # Check against available models
+        print(f"Error: Model name '{model_name}' not found or checkpoint missing.")
+        return None
+
+    checkpoint_path = models_available[model_name] # Get path from available dict
+
+    try:
+        print(f" Loading checkpoint: {checkpoint_path}")
+        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
+        if 'model_cfg' not in checkpoint:
+            print(f"Error: 'model_cfg' key not found in checkpoint {checkpoint_path}.")
+            return None
+        model_cfg_name = checkpoint['model_cfg']
+        print(f" Using model config from checkpoint: {model_cfg_name}")
+        sam2_model = build_sam2(model_cfg_name, checkpoint_path=None, device=DEVICE)
+        if 'model_state_dict' not in checkpoint:
+            print(f"Error: 'model_state_dict' not found in checkpoint {checkpoint_path}.")
+            return None
+        state_dict = checkpoint['model_state_dict']
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            name = k[7:] if k.startswith('module.') else k
+            new_state_dict[name] = v
+        sam2_model.load_state_dict(new_state_dict)
+        print(" Successfully loaded fine-tuned model state_dict.")
+        sam2_model.to(DEVICE)
+        sam2_model.eval()
+        predictor = SAM2ImagePredictor(sam2_model)
+        print(f"Model '{model_name}' loaded successfully on {DEVICE}.")
+        return predictor
+    except Exception as e:
+        print(f"Error loading model {model_name}: {e}")
+        import traceback
+        traceback.print_exc()
+        return None
+
+# --- Utility Functions ---
+# ...existing code...
+def resize_image_fixed(image, target_size=1024):
+    """Resizes image to a fixed square size (1024x1024)."""
+    if image is None: return None
+    return cv2.resize(image, (target_size, target_size), interpolation=cv2.INTER_LINEAR)
+
+def draw_points_on_image(image, points_state):
+    """Draws points (green positive, red negative) on a copy of the image."""
+    if image is None: return image # Return original if no image
+    draw_image = image.copy()
+    if not points_state: return draw_image # Return copy if no points
+
+    # Make points slightly larger and add a black border
+    base_radius = max(4, int(min(image.shape[:2]) * 0.006)) # Slightly larger base radius
+    border_thickness = 1 # Thickness of the black border
+    radius_with_border = base_radius + border_thickness
+    thickness = -1 # Filled circle
+
+    for x, y, label in points_state:
+        color = (0, 255, 0) if label == 1 else (255, 0, 0)
+        center = (int(x), int(y))
+        # Draw black border circle first
+        cv2.circle(draw_image, center, radius_with_border, (0, 0, 0), thickness)
+        # Draw colored circle on top
+        cv2.circle(draw_image, center, base_radius, color, thickness)
+
+    return draw_image
+
+# --- Gradio UI Interaction Functions ---
+
+def get_point_counts_text(points_state):
+    """Helper function to generate the point count markdown string."""
+    pos_count = sum(1 for _, _, label in points_state if label == 1)
+    neg_count = sum(1 for _, _, label in points_state if label == 0)
+    return f"**Points Added:** <font color='green'>{pos_count} Positive</font>, <font color='red'>{neg_count} Negative</font>"
+
+def add_point(preprocessed_image, points_state, point_type, evt: gr.SelectData):
+    """Callback function when user clicks on the preprocessed image."""
+    if preprocessed_image is None:
+        gr.Warning("Please upload and preprocess an image first.")
+        # Return original image, points state, and existing counts text
+        return preprocessed_image, points_state, get_point_counts_text(points_state)
+    x, y = evt.index[0], evt.index[1]
+    label = 1 if point_type == "Positive" else 0
+    # Store coordinates relative to the preprocessed image (1024x1024)
+    points_state.append([x, y, label])
+    print(f"Added point: ({x}, {y}), Type: {'Positive' if label==1 else 'Negative'}, Total Points: {len(points_state)}")
+    image_with_points = draw_points_on_image(preprocessed_image, points_state)
+    # Return updated image, points state, and updated counts text
+    return image_with_points, points_state, get_point_counts_text(points_state)
+
+def undo_last_point(preprocessed_image, points_state):
+    """Removes the last added point and updates the preprocessed display image."""
+    if preprocessed_image is None: # Handle case where image is cleared
+        # Return None image, points state, and counts text
+        return None, points_state, get_point_counts_text(points_state)
+    if not points_state:
+        print("No points to undo.")
+        # Return the current preprocessed image without changes if no points
+        return preprocessed_image, points_state, get_point_counts_text(points_state)
+
+    removed_point = points_state.pop()
+    print(f"Removed point: {removed_point}, Remaining Points: {len(points_state)}")
+    image_with_points = draw_points_on_image(preprocessed_image, points_state)
+    # Return updated image, points state, and updated counts text
+    return image_with_points, points_state, get_point_counts_text(points_state)
+
+def clear_points_and_display(preprocessed_image_state):
+    """Clears points and resets the preprocessed display image."""
+    print("Clearing points and resetting preprocessed display.")
+    points_state = [] # Clear points
+    # Return the stored preprocessed image without points, clear points state, clear mask, clear counts text
+    return preprocessed_image_state, points_state, None, get_point_counts_text(points_state)
+
+def run_segmentation(preprocessed_image_state, original_image_state, model_name, points_state):
+    """Runs SAM2 segmentation using points on the preprocessed image."""
+    start_total_time = time.time()
+    # Initialize return values
+    output_mask_display = None
+
+    if preprocessed_image_state is None or original_image_state is None:
+        gr.Warning("Please upload an image first.")
+        return output_mask_display, points_state
+
+    print(f"\n--- Running Segmentation ---")
+    print(f" Model Selected: {model_name}")
+    print(f" Number of points: {len(points_state)}")
+
+    # --- 1. Load Model ---
+    predictor = load_model(model_name)
+    if predictor is None:
+        gr.Error(f"Failed to load model '{model_name}'. Check logs and paths.")
+        return output_mask_display, points_state
+
+    # --- 2. Use Preprocessed Image ---
+    # The image is already preprocessed and resized to 1024x1024
+    image_for_predictor = preprocessed_image_state
+    original_h, original_w = original_image_state.shape[:2] # Get original dims for final resize
+    print(f" Using preprocessed image (1024x1024) for predictor.")
+    print(f" Original image size for final mask resize: {original_w}x{original_h}")
+
+    print(" Setting preprocessed image in predictor...")
+    start_set_image = time.time()
+    # Feed the preprocessed image (which is already 1024x1024 uint8) to SAM
+    predictor.set_image(image_for_predictor)
+    print(f" predictor.set_image took {time.time() - start_set_image:.2f}s")
+
+    # --- 3. Prepare Prompts (No Scaling Needed) ---
+    if not points_state:
+        # Use center point if no points provided
+        center_x, center_y = 512, 512
+        point_coords = np.array([[[center_x, center_y]]])
+        point_labels = np.array([1])
+        print(" No points provided. Using center point (512, 512).")
+    else:
+        # Points are already relative to the 1024x1024 preprocessed image
+        point_coords_list = [[x, y] for x, y, label in points_state]
+        labels_list = [label for x, y, label in points_state]
+        point_coords = np.array([point_coords_list])
+        point_labels = np.array(labels_list)
+        print(f" Using {len(points_state)} provided points (coords relative to 1024x1024).")
+
+    point_coords_torch = torch.tensor(point_coords, dtype=torch.float32).to(DEVICE)
+    point_labels_torch = torch.tensor(point_labels, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dim
+
+    # --- 4. Run Model Inference ---
+    print(" Running model inference...")
+    start_inference_time = time.time()
+    with torch.no_grad():
+        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
+            points=(point_coords_torch, point_labels_torch), boxes=None, masks=None
+        )
+        if predictor._features is None:
+            gr.Error("Image features not computed. Predictor might not have been set correctly.")
+            return output_mask_display, points_state
+        # Ensure features are accessed correctly
+        image_embed = predictor._features["image_embed"][-1].unsqueeze(0)
+        image_pe = predictor.model.sam_prompt_encoder.get_dense_pe()
+        # Handle potential missing high_res_features key gracefully
+        high_res_features = None
+        if "high_res_feats" in predictor._features and predictor._features["high_res_feats"]:
+            try:
+                high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
+            except IndexError:
+                print("Warning: Index error accessing high_res_feats. Proceeding without them.")
+            except Exception as e:
+                print(f"Warning: Error processing high_res_features: {e}. Proceeding without them.")
+
+        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
+            image_embeddings=image_embed, image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
+            multimask_output=True, repeat_image=False, # repeat_image should be False for single image prediction
+            high_res_features=high_res_features, # Pass None if not available
+        )
+        # Postprocess masks to 1024x1024
+        prd_masks_1024 = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1]) # predictor._orig_hw should be (1024, 1024)
+        # Select the best mask based on predicted score
+        best_mask_idx = torch.argmax(prd_scores[0]).item()
+        # Apply sigmoid and thresholding
+        best_mask_1024_prob = torch.sigmoid(prd_masks_1024[:, best_mask_idx])
+        binary_mask_1024 = (best_mask_1024_prob > 0.5).cpu().numpy().squeeze() # Squeeze to get (H, W)
+    print(f" Model inference took {time.time() - start_inference_time:.2f}s")
+
+    # --- 5. Resize Mask to Original Dimensions ---
+    print(" Resizing mask to original dimensions...")
+    final_mask_resized = cv2.resize(
+        binary_mask_1024.astype(np.uint8), (original_w, original_h), interpolation=cv2.INTER_NEAREST
+    )
+
+    # --- 6. Format Mask for Display ---
+    # Mask for display (RGB)
+    output_mask_display = (final_mask_resized * 255).astype(np.uint8)
+    if len(output_mask_display.shape) == 2: # Ensure RGB for display consistency
+        output_mask_display = cv2.cvtColor(output_mask_display, cv2.COLOR_GRAY2RGB)
+
+    total_time = time.time() - start_total_time
+    print(f"--- Segmentation Complete (Total time: {total_time:.2f}s) ---")
+
+    # Return: mask for display, points state (unchanged)
+    return output_mask_display, points_state # No change needed here as it doesn't modify points
+
+
+def process_upload(uploaded_image):
+    """Handles image upload: preprocesses, resizes, stores states."""
+    if uploaded_image is None:
+        # Clear everything including point counts
+        return None, None, None, [], None, get_point_counts_text([])
+
+    print("Image uploaded. Processing...")
+    # 1. Store original image
+    original_image = uploaded_image.copy()
+
+    # 2. Resize to 1024x1024 for preprocessing
+    image_resized_1024 = resize_image_fixed(original_image, 1024)
+    if image_resized_1024 is None:
+        gr.Error("Failed to resize image.")
+        return None, None, None, [], None, get_point_counts_text([])
+
+    # 3. Preprocess the 1024x1024 image
+    preprocessed_1024 = preprocess_image_for_sam2(image_resized_1024)
+    if preprocessed_1024 is None:
+        gr.Error("Image preprocessing failed.")
+        return None, None, None, [], None, get_point_counts_text([])
+
+    # Ensure preprocessed image is RGB for display
+    if len(preprocessed_1024.shape) == 2:
+        preprocessed_1024_display = cv2.cvtColor(preprocessed_1024, cv2.COLOR_GRAY2RGB)
+    else:
+        preprocessed_1024_display = preprocessed_1024.copy()
+
+    print("Image processed successfully.")
+    points_state = [] # Clear points on new upload
+    # Return:
+    # 1. Preprocessed image for display (interactive)
+    # 2. Preprocessed image for state
+    # 3. Original image for state
+    # 4. Cleared points state
+    # 5. Cleared mask display
+    # 6. Cleared point counts text
+    return preprocessed_1024_display, preprocessed_1024, original_image, points_state, None, get_point_counts_text(points_state)
+
+
+def clear_all_outputs():
+    """Clears all input/output fields and states."""
+    print("Clearing all inputs and outputs.")
+    points_state = [] # Clear points
+    # Clear everything including point counts
+    return None, None, None, points_state, None, get_point_counts_text(points_state)
+
+
+# --- Build Gradio Interface ---
+css = """
+#mask_display_container .gradio-image { height: 450px !important; object-fit: contain; }
+#preprocessed_image_container .gradio-image { height: 450px !important; object-fit: contain; cursor: crosshair !important; }
+#upload_container .gradio-image { height: 150px !important; object-fit: contain; } /* Smaller upload preview */
+.output-col img { max-height: 450px; object-fit: contain; }
+.control-col { min-width: 500px; } /* Wider control column */
+.output-col { min-width: 500px; }
+"""
+
+with gr.Blocks(css=css, title="Coronary Artery Segmentation (Fine-tuned SAM2)") as demo:
+    gr.Markdown("# Coronary Artery Segmentation using Fine-tuned SAM2")
+    gr.Markdown(
+        "**Let's find those arteries!**\n\n"
+        "1. Upload your Coronary X-ray Image.\n"
+        "2. The preprocessed image appears on the left. Time to guide the AI! Click directly on the image to add **Positive** (artery) or **Negative** (background) points.\n"
+        "3. Choose your fine-tuned SAM2 model.\n"
+        "4. Hit 'Run Segmentation' and watch the magic happen!\n"
+        "5. Download your predicted mask (the white area) using the download button on the mask image."
+    )
+
+    # --- States ---
+    points_state = gr.State([])
+    # State to store the original uploaded image (needed for final mask resizing)
+    original_image_state = gr.State(None)
+    # State to store the preprocessed 1024x1024 image data (used for drawing points and predictor input)
+    preprocessed_image_state = gr.State(None)
+
+
+    with gr.Row():
+        # --- Left Column (Controls & Preprocessed Image Interaction) ---
+        with gr.Column(scale=1, elem_classes="control-col"):
+            gr.Markdown("## 1. Upload & Controls")
+            # Keep upload separate and smaller
+            upload_image = gr.Image(
+                type="numpy", label="Upload Coronary X-ray Image",
+                height=150, elem_id="upload_container"
+            )
+            gr.Markdown("## 2. Add Points on Preprocessed Image")
+            # Interactive Preprocessed Image Display
+            preprocessed_image_display = gr.Image(
+                type="numpy", label="Click on Image to Add Points",
+                interactive=True, # Make this interactive
+                height=450, elem_id="preprocessed_image_container"
+            )
+            # Add Point Counter Display
+            point_counter_display = gr.Markdown(get_point_counts_text([]))
+
+            model_selector = gr.Dropdown(
+                choices=list(models_available.keys()),
+                label="Select SAM2 Model Variant",
+                value=list(models_available.keys())[-1] if models_available else None
+            )
+            prompt_type = gr.Radio(
+                ["Positive", "Negative"], label="Point Prompt Type", value="Positive"
+            )
+            with gr.Row():
+                clear_button = gr.Button("Clear Points")
+                undo_button = gr.Button("Undo Last Point")
+            run_button = gr.Button("Run Segmentation", variant="primary")
+            clear_all_button = gr.Button("Clear All") # Added Clear All
+
+        # --- Right Column (Output Mask) ---
+        with gr.Column(scale=1, elem_classes="output-col"):
+            gr.Markdown("## 3. Predicted Mask")
+            final_mask_display = gr.Image(
+                type="numpy", label="Predicted Binary Mask (White = Artery)",
+                interactive=False, height=450, elem_id="mask_display_container",
+                format="png" # Specify PNG format for download
+            )
+
+
+    # --- Define Interactions ---
+
+    # 1. Upload triggers preprocessing and display
+    upload_image.upload(
+        fn=process_upload,
+        inputs=[upload_image],
+        outputs=[
+            preprocessed_image_display, # Update interactive display
+            preprocessed_image_state, # Update state
+            original_image_state, # Update state
+            points_state, # Clear points
+            final_mask_display, # Clear mask display
+            point_counter_display # Clear point counts
+        ]
+    )
+
+    # 2. Clicking on preprocessed image adds points
+    preprocessed_image_display.select(
+        fn=add_point,
+        inputs=[preprocessed_image_state, points_state, prompt_type],
+        outputs=[
+            preprocessed_image_display, # Update display with points
+            points_state, # Update points state
+            point_counter_display # Update point counts
+        ]
+    )
+
+    # 3. Clear points button resets points and preprocessed display
+    clear_button.click(
+        fn=clear_points_and_display,
+        inputs=[preprocessed_image_state], # Needs the clean preprocessed image
+        outputs=[
+            preprocessed_image_display, # Reset display
+            points_state, # Clear points
+            final_mask_display, # Clear mask
+            point_counter_display # Reset point counts
+        ]
+    )
+
+    # 4. Undo button removes last point and updates preprocessed display
+    undo_button.click(
+        fn=undo_last_point,
+        inputs=[preprocessed_image_state, points_state], # Needs current preprocessed image
+        outputs=[
+            preprocessed_image_display, # Update display
+            points_state, # Update points state
+            point_counter_display # Update point counts
+        ]
+    )
+
+    # 5. Run segmentation (Outputs don't change point counts)
+    run_button.click(
+        fn=run_segmentation,
+        inputs=[
+            preprocessed_image_state, # Use preprocessed image data
+            original_image_state, # Needed for final resize dim
+            model_selector,
+            points_state # Points are relative to preprocessed
+        ],
+        outputs=[
+            final_mask_display, # Show the final mask
+            points_state # Pass points state (might be needed if run modifies it - currently doesn't)
+        ]
+    )
+
+    # 7. Clear All button
+    clear_all_button.click(
+        fn=clear_all_outputs,
+        inputs=[],
+        outputs=[
+            upload_image,
+            preprocessed_image_display,
+            preprocessed_image_state,
+            points_state,
+            final_mask_display,
+            point_counter_display # Reset point counts
+        ]
+    )
+
+
+# --- Launch the App ---
+if __name__ == "__main__":
+    print("Launching Gradio App...")
+    demo.launch()
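One coordinate-handling detail in app.py is easy to miss: clicks are recorded on the preprocessed 1024x1024 image (process_upload resizes every upload to a fixed square), so they are fed to the predictor unscaled, and only the final mask is resized back to the original resolution in run_segmentation. The helpers below are a hypothetical sketch, not part of app.py, of the two mappings you would need if you instead collected clicks on the original-resolution image:

```python
import cv2
import numpy as np

TARGET = 1024  # app.py resizes every upload to a fixed 1024x1024 square


def original_to_model_coords(x, y, original_w, original_h, target=TARGET):
    """Scale a click from original-image pixels into the 1024x1024 space
    that points_state and the predictor use."""
    return x * target / original_w, y * target / original_h


def mask_to_original(mask_1024, original_w, original_h):
    """Resize a binary 1024x1024 mask back to the original image size,
    using nearest-neighbour interpolation as run_segmentation does."""
    return cv2.resize(mask_1024.astype(np.uint8), (original_w, original_h),
                      interpolation=cv2.INTER_NEAREST)
```

Because resize_image_fixed stretches the upload to a square regardless of its aspect ratio, the x and y scale factors generally differ, which is why both are computed per axis above.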
packages.txt
ADDED
@@ -0,0 +1,3 @@
+libgl1
+libglib2.0-0
+
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+gradio
+torch
+torchvision
+numpy
+opencv-python-headless
+Pillow
+huggingface_hub
+git+https://github.com/facebookresearch/segment-anything-2.git