anudef committed on
Commit 562f83f · 0 Parent(s):

Initial Space setup - models download from HF Hub

Files changed (7)
  1. .gitattributes +36 -0
  2. .gitignore +42 -0
  3. ICARp_Updated.svg +1 -0
  4. README.md +165 -0
  5. app.py +575 -0
  6. packages.txt +3 -0
  7. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ **/*.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,42 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ *.egg-info/
+ dist/
+ build/
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+
+ # HuggingFace cache
+ hf_cache/
+ .cache/
+
+ # Gradio
+ gradio_cached_examples/
+ flagged/
+
+ # Model checkpoints (if testing locally)
+ *.pt
+ *.pth
+ *.ckpt
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ ft_models/
ICARp_Updated.svg ADDED
README.md ADDED
@@ -0,0 +1,165 @@
+ # 🫀 Coronary Artery Segmentation with Fine-tuned SAM2
+
+ An interactive tool for segmenting coronary arteries in X-ray angiography images using the fine-tuned Segment Anything Model 2 (SAM2). This Space provides an easy-to-use interface for medical image segmentation with point-based prompting.
+
+ ---
+
+ ## 🎯 How to Use This Space
+
+ ### Step 1: Upload Your Image
+ - Click on the image upload area or drag and drop your coronary X-ray angiography image
+ - Supported formats: JPG, PNG, and other common image formats
+ - The image will be displayed in the interface for annotation
+
+ ### Step 2: Select a Model
+ Choose from four fine-tuned SAM2 variants optimized for coronary artery segmentation:
+ - **SAM2 Hiera Tiny** - Fastest inference, good for quick results
+ - **SAM2 Hiera Small** - Balanced speed and accuracy
+ - **SAM2 Hiera Base Plus** - Higher accuracy, moderate speed
+ - **SAM2 Hiera Large** - Best accuracy, slower inference
+
+ ### Step 3: Add Point Prompts
+
+ #### 🟢 Positive Points (Green)
+ - Click on areas you **want to include** in the segmentation
+ - Use positive points to mark the coronary arteries you want to segment
+ - Multiple positive points help the model understand the full extent of the artery
+ - **Tip**: Add points along the length of the artery for better coverage
+
+ #### 🔴 Negative Points (Red)
+ - Click on areas you **want to exclude** from the segmentation
+ - Use negative points to refine the segmentation and remove unwanted regions
+ - Helpful for separating overlapping vessels or removing background
+ - **Tip**: Add negative points near the artery boundaries to improve precision
+
+ #### 🎨 How to Add Points
+ 1. Choose **Positive** or **Negative** under "Point Prompt Type" (Positive is the default)
+ 2. Click directly on the preprocessed image to place a point of the selected type
+ 3. You can add multiple points of each type; use **"Undo Last Point"** to remove the most recent one
+ 4. Points are displayed as colored dots on your image (see the sketch below for how they are stored internally)
+
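+ Under the hood, `app.py` stores each click as an `[x, y, label]` triple with coordinates on the 1024×1024 preprocessed image, where `label` is 1 for positive and 0 for negative points. A minimal sketch (with illustrative coordinate values) of how these triples become the prompt arrays used at inference time:
+
+ ```python
+ import numpy as np
+
+ # Points as stored by app.py's add_point(): [x, y, label] on the 1024x1024 image
+ points_state = [[420, 310, 1], [455, 362, 1], [500, 300, 0]]  # illustrative values
+
+ # Arrays built the same way run_segmentation() does before inference
+ point_coords = np.array([[[x, y] for x, y, _ in points_state]])   # shape (1, N, 2)
+ point_labels = np.array([label for _, _, label in points_state])  # shape (N,)
+ print(point_coords.shape, point_labels)  # (1, 3, 2) [1 1 0]
+ ```
+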
+ ### Step 4: Generate Segmentation
+ - Click the **"Run Segmentation"** button to process your annotations
+ - The model will generate a segmentation mask based on your point prompts
+ - The result is shown as a binary mask (white = artery) at the original image resolution
+
+ ### Step 5: Refine (Optional)
+ - If the segmentation isn't perfect, add more points:
+   - Add positive points to include missing artery sections
+   - Add negative points to exclude over-segmented regions
+ - Click **"Run Segmentation"** again to update the result
+ - Repeat until you're satisfied with the segmentation
+
+ ### Step 6: Clear and Start Over
+ - Use the **"Clear Points"** button to remove all points, or **"Clear All"** to reset everything and start fresh
+ - Upload a new image to segment a different case
+
+ ---
+
+ ## 💡 Tips for Best Results
+
+ 1. **Start Simple**: Begin with 1-2 positive points on the main artery structure
+ 2. **Be Strategic**: Place positive points at key locations (branches, endpoints, curves)
+ 3. **Refine Gradually**: Add negative points only where the model over-segments
+ 4. **Model Selection**: Start with the Tiny or Small models for faster iteration, then use larger models for final results
+ 5. **Multiple Vessels**: For multiple arteries, focus on one at a time with clear positive points
+ 6. **Contrast Matters**: Images with good contrast between vessels and background work best
+
+ ---
+
+ ## 🔬 Model Architecture
+
+ ![Coronary Artery Segmentation Workflow](ICARp_Updated.svg)
+
+ The workflow consists of several key stages (a minimal programmatic sketch follows the list):
+ 1. **Image Preprocessing** - Normalization and enhancement for optimal model input
+ 2. **Point Prompt Encoding** - Your positive/negative clicks are encoded as spatial prompts
+ 3. **SAM2 Processing** - Fine-tuned encoder-decoder processes the image with prompts
+ 4. **Mask Generation** - High-quality segmentation mask output
+ 5. **Visualization** - Segmented artery overlaid on the original image
+
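+ If you prefer to drive the same pipeline from a script, the sketch below mirrors these stages using the high-level `SAM2ImagePredictor.predict()` call rather than the lower-level prompt-encoder/mask-decoder calls used in `app.py`. The repository ID and checkpoint filename are the ones this Space uses; the input filename is illustrative and the X-ray normalization/CLAHE steps are omitted, so treat it as an outline under those assumptions rather than a drop-in replacement for `app.py`.
+
+ ```python
+ import cv2
+ import numpy as np
+ import torch
+ from huggingface_hub import hf_hub_download
+ from sam2.build_sam import build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+ # 1. Fetch a fine-tuned checkpoint from the Hub (same repo/filenames as app.py)
+ ckpt = torch.load(hf_hub_download("astroanand/CoronarySAM2", "Coronary_Sam2_t.pt"), map_location="cpu")
+
+ # 2. Build SAM2 from the config name stored in the checkpoint, then load the fine-tuned weights
+ model = build_sam2(ckpt["model_cfg"], None, device="cpu")
+ model.load_state_dict({k.removeprefix("module."): v for k, v in ckpt["model_state_dict"].items()})
+ predictor = SAM2ImagePredictor(model.eval())
+
+ # 3. Preprocess: resize to 1024x1024 (app.py additionally normalizes the X-ray and applies CLAHE)
+ image = cv2.cvtColor(cv2.imread("angiogram.png"), cv2.COLOR_BGR2RGB)  # illustrative input file
+ predictor.set_image(cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR))
+
+ # 4. Point prompts on the 1024x1024 image: label 1 = artery (positive), 0 = background (negative)
+ masks, scores, _ = predictor.predict(
+     point_coords=np.array([[512, 512]]),
+     point_labels=np.array([1]),
+     multimask_output=True,
+ )
+
+ # 5. Keep the highest-scoring mask and resize it back to the original image size
+ best = masks[np.argmax(scores)].astype(np.uint8)
+ mask_full = cv2.resize(best, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
+ ```
+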
+ ---
+
+ ## 📊 Model Variants
+
+ | Model | Parameters | Speed | Accuracy | Best For |
+ |-------|-----------|-------|----------|----------|
+ | Hiera Tiny | ~38M | ⚡⚡⚡ | ⭐⭐⭐ | Quick experiments, real-time feedback |
+ | Hiera Small | ~46M | ⚡⚡ | ⭐⭐⭐⭐ | Balanced performance, general use |
+ | Hiera Base Plus | ~80M | ⚡ | ⭐⭐⭐⭐⭐ | High-quality results, clinical evaluation |
+ | Hiera Large | ~224M | ⚡ | ⭐⭐⭐⭐⭐ | Best possible accuracy, research |
+
+ All models have been fine-tuned specifically for coronary artery segmentation on X-ray angiography images.
+
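+ Each variant in the table corresponds to one checkpoint file in the `astroanand/CoronarySAM2` repository (`Coronary_Sam2_t.pt`, `Coronary_Sam2_s.pt`, `Coronary_Sam2_b+.pt`, `Coronary_Sam2_l.pt`). As `load_model()` in `app.py` shows, each checkpoint is a plain `torch.save` dictionary with two keys the app relies on; a minimal sketch of inspecting one:
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ # Any of the four variants works here; the Tiny checkpoint is the smallest download
+ ckpt = torch.load(
+     hf_hub_download("astroanand/CoronarySAM2", "Coronary_Sam2_t.pt"),
+     map_location="cpu",
+ )
+ print(ckpt["model_cfg"])              # SAM2 Hiera config name recorded at fine-tuning time
+ print(len(ckpt["model_state_dict"]))  # fine-tuned weights (keys may carry a 'module.' prefix)
+ ```
+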
+ ---
+
+ ## ⚠️ Disclaimer
+
+ This tool is intended for **research and educational purposes only**. It is not approved for clinical diagnosis or treatment decisions. Always consult qualified healthcare professionals for medical image interpretation.
+
+ ---
+
+ ## 🚀 Local Installation
+
+ If you want to run this application locally, follow these setup instructions:
+
+ ### 1. Create Conda Environment
+
+ ```bash
+ # Create a new conda environment named sam2_FT_env with Python 3.10
+ conda create -n sam2_FT_env python=3.10.0 -y
+
+ # Activate the environment
+ conda activate sam2_FT_env
+ ```
+
+ ### 2. Install SAM2 Library
+
+ ```bash
+ # Clone the official repository
+ git clone https://github.com/facebookresearch/segment-anything-2.git
+ cd segment-anything-2
+
+ # Install the package in editable mode
+ pip install -e .
+ cd ..
+ ```
+
+ ### 3. Install Dependencies
+
+ ```bash
+ pip install uv
+ uv pip install gradio opencv-python-headless torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130
+ ```
+
+ ### 4. Run the Application
+
+ ```bash
+ # Ensure the fine-tuned checkpoints are available (see the download sketch below)
+ python app.py
+ ```
+
+ Open `http://127.0.0.1:7860` in your browser.
+
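+ Note that `app.py` downloads the four fine-tuned checkpoints from the Hugging Face Hub into `./hf_cache` at startup, so no manual placement is strictly required. If you want to pre-fetch them (for example on a machine with intermittent connectivity), here is a sketch using the same repository, filenames, and cache directory as the app:
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Same repo, filenames, and cache directory that app.py uses at startup
+ for filename in ["Coronary_Sam2_t.pt", "Coronary_Sam2_s.pt", "Coronary_Sam2_b+.pt", "Coronary_Sam2_l.pt"]:
+     path = hf_hub_download(
+         repo_id="astroanand/CoronarySAM2",
+         filename=filename,
+         cache_dir="./hf_cache",
+     )
+     print(f"cached {filename} -> {path}")
+ ```
+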
+ ---
+
+ ## 📚 Citation
+
+ If you use this tool in your research, please cite:
+
+ ```bibtex
+ @article{ravi2024sam2,
+   title={SAM 2: Segment Anything in Images and Videos},
+   author={Ravi, Nikhila and others},
+   journal={arXiv preprint arXiv:2408.00714},
+   year={2024}
+ }
+ ```
+
+ ---
+
+ ## 📧 Contact & Support
+
+ For questions, issues, or feedback, please open an issue on the GitHub repository or contact the developers.
+
+ **Built with ❤️ using SAM2 and Gradio**
app.py ADDED
@@ -0,0 +1,575 @@
+ import os
+ import gradio as gr
+ import torch
+ import numpy as np
+ import cv2
+ import time
+ import functools
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+
+ # --- Configuration ---
+
+ # Hugging Face model repositories and filenames
+ HF_MODEL_CONFIG = {
+     "SAM2 Hiera Tiny": {
+         "repo_id": "astroanand/CoronarySAM2",
+         "filename": "Coronary_Sam2_t.pt"
+     },
+     "SAM2 Hiera Small": {
+         "repo_id": "astroanand/CoronarySAM2",
+         "filename": "Coronary_Sam2_s.pt"
+     },
+     "SAM2 Hiera Base Plus": {
+         "repo_id": "astroanand/CoronarySAM2",
+         "filename": "Coronary_Sam2_b+.pt"
+     },
+     "SAM2 Hiera Large": {
+         "repo_id": "astroanand/CoronarySAM2",
+         "filename": "Coronary_Sam2_l.pt"
+     }
+ }
+
+ # Download and cache models from Hugging Face
+ print("Checking and downloading models from Hugging Face...")
+ models_available = {}
+ for name, config in HF_MODEL_CONFIG.items():
+     try:
+         print(f"Downloading {name} from {config['repo_id']}...")
+         model_path = hf_hub_download(
+             repo_id=config["repo_id"],
+             filename=config["filename"],
+             cache_dir="./hf_cache"
+         )
+         models_available[name] = model_path
+         print(f"✓ {name} downloaded successfully to {model_path}")
+     except Exception as e:
+         print(f"✗ Warning: Failed to download {name} from {config['repo_id']}: {e}")
+         print(f" {name} will not be available in the dropdown.")
+
+ if not models_available:
+     print("Error: No valid models could be downloaded. Please check your internet connection and HF repository access.")
+     # exit()  # Or handle gracefully
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+ print(f"Using device: {DEVICE}")
+
+ # Try importing SAM2 modules
+ try:
+     from sam2.build_sam import build_sam2
+     from sam2.sam2_image_predictor import SAM2ImagePredictor
+ except ImportError:
+     print("Error: SAM2 modules not found. Make sure 'sam2' directory is in your Python path or installed.")
+     exit()
+
+ # --- Preprocessing Functions ---
+ def normalize_xray_image(image, kernel_size=(51,51), sigma=0):
+     """Normalize X-ray image by applying Gaussian blur and intensity normalization."""
+     if image is None: return None
+     is_color = len(image.shape) == 3
+     if is_color:
+         gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+     else:
+         gray_image = image.copy()
+     gray_image = gray_image.astype(float)
+     blurred = gray_image.copy()
+     for _ in range(5):  # Reduced iterations
+         blurred = cv2.GaussianBlur(blurred, kernel_size, sigma)
+     mean_intensity = np.mean(blurred)
+     factor_image = mean_intensity / (blurred + 1e-10)
+     if is_color:
+         normalized_image = image.copy().astype(float)
+         for i in range(3):
+             normalized_image[:,:,i] = normalized_image[:,:,i] * factor_image
+     else:
+         normalized_image = gray_image * factor_image
+     return np.clip(normalized_image, 0, 255).astype(np.uint8)
+
+ def apply_clahe(image_uint8):
+     """Apply CLAHE for better vessel contrast. Expects uint8 input."""
+     if image_uint8 is None: return None
+     is_color = len(image_uint8.shape) == 3
+
+     # --- ADJUST CLAHE STRENGTH HERE ---
+     # Lower clipLimit reduces the contrast enhancement effect.
+     # Original was 2.0. Try values like 1.5, 1.0, or even disable by setting it very low.
+     clahe_clip_limit = 2.0
+     clahe_tile_grid_size = (8, 8)
+     print(f" Applying CLAHE with clipLimit={clahe_clip_limit}, tileGridSize={clahe_tile_grid_size}")
+     # ---------------------------------
+
+     clahe = cv2.createCLAHE(clipLimit=clahe_clip_limit, tileGridSize=clahe_tile_grid_size)
+
+     if is_color:
+         lab = cv2.cvtColor(image_uint8, cv2.COLOR_RGB2LAB)
+         l, a, b = cv2.split(lab)
+         l_clahe = clahe.apply(l)
+         lab_clahe = cv2.merge((l_clahe, a, b))
+         clahe_image_uint8 = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
+     else:
+         clahe_image_uint8 = clahe.apply(image_uint8)
+
+     # Return uint8 [0, 255] suitable for predictor.set_image
+     return clahe_image_uint8
+
+
+ def preprocess_image_for_sam2(image_rgb_numpy):
+     """Combined preprocessing: normalization + CLAHE for SAM2 input."""
+     if image_rgb_numpy is None:
+         print("Preprocessing: Input image is None.")
+         return None
+
+     start_time = time.time()
+     print("Preprocessing Step 1: Normalizing X-ray image...")
+     if image_rgb_numpy.dtype != np.uint8:
+         image_rgb_numpy = np.clip(image_rgb_numpy, 0, 255).astype(np.uint8)
+     if len(image_rgb_numpy.shape) == 2:
+         image_rgb_numpy = cv2.cvtColor(image_rgb_numpy, cv2.COLOR_GRAY2RGB)
+
+     normalized_uint8 = normalize_xray_image(image_rgb_numpy)
+     if normalized_uint8 is None:
+         print("Preprocessing failed at normalization step.")
+         return None
+     print(f"Normalization done in {time.time() - start_time:.2f}s")
+
+     start_time_clahe = time.time()
+     print("Preprocessing Step 2: Applying CLAHE...")
+     preprocessed_uint8 = apply_clahe(normalized_uint8)  # CLAHE applied here
+     if preprocessed_uint8 is None:
+         print("Preprocessing failed at CLAHE step.")
+         return None
+     print(f"CLAHE done in {time.time() - start_time_clahe:.2f}s")
+     print(f"Total preprocessing time: {time.time() - start_time:.2f}s")
+
+     return preprocessed_uint8  # Return the image after all steps
+
+ # --- Model Loading ---
+ @functools.lru_cache(maxsize=4)  # Cache up to 4 models (one for each variant)
+ def load_model(model_name):
+     """Loads the specified SAM2 model and creates a predictor."""
+     print(f"\nAttempting to load model: {model_name}")
+     if model_name not in models_available:  # Check against available models
+         print(f"Error: Model name '{model_name}' not found or checkpoint missing.")
+         return None
+
+     checkpoint_path = models_available[model_name]  # Get path from available dict
+
+     try:
+         print(f" Loading checkpoint: {checkpoint_path}")
+         checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
+         if 'model_cfg' not in checkpoint:
+             print(f"Error: 'model_cfg' key not found in checkpoint {checkpoint_path}.")
+             return None
+         model_cfg_name = checkpoint['model_cfg']
+         print(f" Using model config from checkpoint: {model_cfg_name}")
+         sam2_model = build_sam2(model_cfg_name, checkpoint_path=None, device=DEVICE)
+         if 'model_state_dict' not in checkpoint:
+             print(f"Error: 'model_state_dict' not found in checkpoint {checkpoint_path}.")
+             return None
+         state_dict = checkpoint['model_state_dict']
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             name = k[7:] if k.startswith('module.') else k
+             new_state_dict[name] = v
+         sam2_model.load_state_dict(new_state_dict)
+         print(" Successfully loaded fine-tuned model state_dict.")
+         sam2_model.to(DEVICE)
+         sam2_model.eval()
+         predictor = SAM2ImagePredictor(sam2_model)
+         print(f"Model '{model_name}' loaded successfully on {DEVICE}.")
+         return predictor
+     except Exception as e:
+         print(f"Error loading model {model_name}: {e}")
+         import traceback
+         traceback.print_exc()
+         return None
+
190
+ # ...existing code...
191
+ def resize_image_fixed(image, target_size=1024):
192
+ """Resizes image to a fixed square size (1024x1024)."""
193
+ if image is None: return None
194
+ return cv2.resize(image, (target_size, target_size), interpolation=cv2.INTER_LINEAR)
195
+
196
+ def draw_points_on_image(image, points_state):
197
+ """Draws points (green positive, red negative) on a copy of the image."""
198
+ if image is None: return image # Return original if no image
199
+ draw_image = image.copy()
200
+ if not points_state: return draw_image # Return copy if no points
201
+
202
+ # Make points slightly larger and add a black border
203
+ base_radius = max(4, int(min(image.shape[:2]) * 0.006)) # Slightly larger base radius
204
+ border_thickness = 1 # Thickness of the black border
205
+ radius_with_border = base_radius + border_thickness
206
+ thickness = -1 # Filled circle
207
+
208
+ for x, y, label in points_state:
209
+ color = (0, 255, 0) if label == 1 else (255, 0, 0)
210
+ center = (int(x), int(y))
211
+ # Draw black border circle first
212
+ cv2.circle(draw_image, center, radius_with_border, (0, 0, 0), thickness)
213
+ # Draw colored circle on top
214
+ cv2.circle(draw_image, center, base_radius, color, thickness)
215
+
216
+ return draw_image
217
+
218
+ # --- Gradio UI Interaction Functions ---
219
+
220
+ def get_point_counts_text(points_state):
221
+ """Helper function to generate the point count markdown string."""
222
+ pos_count = sum(1 for _, _, label in points_state if label == 1)
223
+ neg_count = sum(1 for _, _, label in points_state if label == 0)
224
+ return f"**Points Added:** <font color='green'>{pos_count} Positive</font>, <font color='red'>{neg_count} Negative</font>"
225
+
226
+ def add_point(preprocessed_image, points_state, point_type, evt: gr.SelectData):
227
+ """Callback function when user clicks on the preprocessed image."""
228
+ if preprocessed_image is None:
229
+ gr.Warning("Please upload and preprocess an image first.")
230
+ # Return original image, points state, and existing counts text
231
+ return preprocessed_image, points_state, get_point_counts_text(points_state)
232
+ x, y = evt.index[0], evt.index[1]
233
+ label = 1 if point_type == "Positive" else 0
234
+ # Store coordinates relative to the preprocessed image (1024x1024)
235
+ points_state.append([x, y, label])
236
+ print(f"Added point: ({x}, {y}), Type: {'Positive' if label==1 else 'Negative'}, Total Points: {len(points_state)}")
237
+ image_with_points = draw_points_on_image(preprocessed_image, points_state)
238
+ # Return updated image, points state, and updated counts text
239
+ return image_with_points, points_state, get_point_counts_text(points_state)
240
+
241
+ def undo_last_point(preprocessed_image, points_state):
242
+ """Removes the last added point and updates the preprocessed display image."""
243
+ if preprocessed_image is None: # Handle case where image is cleared
244
+ # Return None image, points state, and counts text
245
+ return None, points_state, get_point_counts_text(points_state)
246
+ if not points_state:
247
+ print("No points to undo.")
248
+ # Return the current preprocessed image without changes if no points
249
+ return preprocessed_image, points_state, get_point_counts_text(points_state)
250
+
251
+ removed_point = points_state.pop()
252
+ print(f"Removed point: {removed_point}, Remaining Points: {len(points_state)}")
253
+ image_with_points = draw_points_on_image(preprocessed_image, points_state)
254
+ # Return updated image, points state, and updated counts text
255
+ return image_with_points, points_state, get_point_counts_text(points_state)
256
+
257
+ def clear_points_and_display(preprocessed_image_state):
258
+ """Clears points and resets the preprocessed display image."""
259
+ print("Clearing points and resetting preprocessed display.")
260
+ points_state = [] # Clear points
261
+ # Return the stored preprocessed image without points, clear points state, clear mask, clear counts text
262
+ return preprocessed_image_state, points_state, None, get_point_counts_text(points_state)
263
+
+ def run_segmentation(preprocessed_image_state, original_image_state, model_name, points_state):
+     """Runs SAM2 segmentation using points on the preprocessed image."""
+     start_total_time = time.time()
+     # Initialize return values
+     output_mask_display = None
+
+     if preprocessed_image_state is None or original_image_state is None:
+         gr.Warning("Please upload an image first.")
+         return output_mask_display, points_state
+
+     print(f"\n--- Running Segmentation ---")
+     print(f" Model Selected: {model_name}")
+     print(f" Number of points: {len(points_state)}")
+
+     # --- 1. Load Model ---
+     predictor = load_model(model_name)
+     if predictor is None:
+         gr.Error(f"Failed to load model '{model_name}'. Check logs and paths.")
+         return output_mask_display, points_state
+
+     # --- 2. Use Preprocessed Image ---
+     # The image is already preprocessed and resized to 1024x1024
+     image_for_predictor = preprocessed_image_state
+     original_h, original_w = original_image_state.shape[:2]  # Get original dims for final resize
+     print(f" Using preprocessed image (1024x1024) for predictor.")
+     print(f" Original image size for final mask resize: {original_w}x{original_h}")
+
+     print(" Setting preprocessed image in predictor...")
+     start_set_image = time.time()
+     # Feed the preprocessed image (which is already 1024x1024 uint8) to SAM
+     predictor.set_image(image_for_predictor)
+     print(f" predictor.set_image took {time.time() - start_set_image:.2f}s")
+
+     # --- 3. Prepare Prompts (No Scaling Needed) ---
+     if not points_state:
+         # Use center point if no points provided
+         center_x, center_y = 512, 512
+         point_coords = np.array([[[center_x, center_y]]])
+         point_labels = np.array([1])
+         print(" No points provided. Using center point (512, 512).")
+     else:
+         # Points are already relative to the 1024x1024 preprocessed image
+         point_coords_list = [[x, y] for x, y, label in points_state]
+         labels_list = [label for x, y, label in points_state]
+         point_coords = np.array([point_coords_list])
+         point_labels = np.array(labels_list)
+         print(f" Using {len(points_state)} provided points (coords relative to 1024x1024).")
+
+     point_coords_torch = torch.tensor(point_coords, dtype=torch.float32).to(DEVICE)
+     point_labels_torch = torch.tensor(point_labels, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # Add batch dim
+
+     # --- 4. Run Model Inference ---
+     print(" Running model inference...")
+     start_inference_time = time.time()
+     with torch.no_grad():
+         sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
+             points=(point_coords_torch, point_labels_torch), boxes=None, masks=None
+         )
+         if predictor._features is None:
+             gr.Error("Image features not computed. Predictor might not have been set correctly.")
+             return output_mask_display, points_state
+         # Ensure features are accessed correctly
+         image_embed = predictor._features["image_embed"][-1].unsqueeze(0)
+         image_pe = predictor.model.sam_prompt_encoder.get_dense_pe()
+         # Handle potential missing high_res_features key gracefully
+         high_res_features = None
+         if "high_res_feats" in predictor._features and predictor._features["high_res_feats"]:
+             try:
+                 high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
+             except IndexError:
+                 print("Warning: Index error accessing high_res_feats. Proceeding without them.")
+             except Exception as e:
+                 print(f"Warning: Error processing high_res_features: {e}. Proceeding without them.")
+
+         low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
+             image_embeddings=image_embed, image_pe=image_pe,
+             sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
+             multimask_output=True, repeat_image=False,  # repeat_image should be False for single image prediction
+             high_res_features=high_res_features,  # Pass None if not available
+         )
+         # Postprocess masks to 1024x1024
+         prd_masks_1024 = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # predictor._orig_hw should be (1024, 1024)
+         # Select the best mask based on predicted score
+         best_mask_idx = torch.argmax(prd_scores[0]).item()
+         # Apply sigmoid and thresholding
+         best_mask_1024_prob = torch.sigmoid(prd_masks_1024[:, best_mask_idx])
+         binary_mask_1024 = (best_mask_1024_prob > 0.5).cpu().numpy().squeeze()  # Squeeze to get (H, W)
+     print(f" Model inference took {time.time() - start_inference_time:.2f}s")
+
+     # --- 5. Resize Mask to Original Dimensions ---
+     print(" Resizing mask to original dimensions...")
+     final_mask_resized = cv2.resize(
+         binary_mask_1024.astype(np.uint8), (original_w, original_h), interpolation=cv2.INTER_NEAREST
+     )
+
+     # --- 6. Format Mask for Display ---
+     # Mask for display (RGB)
+     output_mask_display = (final_mask_resized * 255).astype(np.uint8)
+     if len(output_mask_display.shape) == 2:  # Ensure RGB for display consistency
+         output_mask_display = cv2.cvtColor(output_mask_display, cv2.COLOR_GRAY2RGB)
+
+     total_time = time.time() - start_total_time
+     print(f"--- Segmentation Complete (Total time: {total_time:.2f}s) ---")
+
+     # Return: mask for display, points state (unchanged)
+     return output_mask_display, points_state  # No change needed here as it doesn't modify points
+
+
+ def process_upload(uploaded_image):
+     """Handles image upload: preprocesses, resizes, stores states."""
+     if uploaded_image is None:
+         # Clear everything including point counts
+         return None, None, None, [], None, get_point_counts_text([])
+
+     print("Image uploaded. Processing...")
+     # 1. Store original image
+     original_image = uploaded_image.copy()
+
+     # 2. Resize to 1024x1024 for preprocessing
+     image_resized_1024 = resize_image_fixed(original_image, 1024)
+     if image_resized_1024 is None:
+         gr.Error("Failed to resize image.")
+         return None, None, None, [], None, get_point_counts_text([])
+
+     # 3. Preprocess the 1024x1024 image
+     preprocessed_1024 = preprocess_image_for_sam2(image_resized_1024)
+     if preprocessed_1024 is None:
+         gr.Error("Image preprocessing failed.")
+         return None, None, None, [], None, get_point_counts_text([])
+
+     # Ensure preprocessed image is RGB for display
+     if len(preprocessed_1024.shape) == 2:
+         preprocessed_1024_display = cv2.cvtColor(preprocessed_1024, cv2.COLOR_GRAY2RGB)
+     else:
+         preprocessed_1024_display = preprocessed_1024.copy()
+
+     print("Image processed successfully.")
+     points_state = []  # Clear points on new upload
+     # Return:
+     # 1. Preprocessed image for display (interactive)
+     # 2. Preprocessed image for state
+     # 3. Original image for state
+     # 4. Cleared points state
+     # 5. Cleared mask display
+     # 6. Cleared point counts text
+     return preprocessed_1024_display, preprocessed_1024, original_image, points_state, None, get_point_counts_text(points_state)
+
+
+ def clear_all_outputs():
+     """Clears all input/output fields and states."""
+     print("Clearing all inputs and outputs.")
+     points_state = []  # Clear points
+     # Clear everything including point counts
+     return None, None, None, points_state, None, get_point_counts_text(points_state)
+
+
+ # --- Build Gradio Interface ---
+ css = """
+ #mask_display_container .gradio-image { height: 450px !important; object-fit: contain; }
+ #preprocessed_image_container .gradio-image { height: 450px !important; object-fit: contain; cursor: crosshair !important; }
+ #upload_container .gradio-image { height: 150px !important; object-fit: contain; } /* Smaller upload preview */
+ .output-col img { max-height: 450px; object-fit: contain; }
+ .control-col { min-width: 500px; } /* Wider control column */
+ .output-col { min-width: 500px; }
+ """
+
+ with gr.Blocks(css=css, title="Coronary Artery Segmentation (Fine-tuned SAM2)") as demo:
+     gr.Markdown("# Coronary Artery Segmentation using Fine-tuned SAM2")
+     gr.Markdown(
+         "**Let's find those arteries!**\n\n"
+         "1. Upload your Coronary X-ray Image.\n"
+         "2. The preprocessed image appears on the left. Time to guide the AI! Click directly on the image to add **Positive** (artery) or **Negative** (background) points.\n"
+         "3. Choose your fine-tuned SAM2 model.\n"
+         "4. Hit 'Run Segmentation' and watch the magic happen!\n"
+         "5. Download your predicted mask (the white area) using the download button on the mask image."
+     )
+
+     # --- States ---
+     points_state = gr.State([])
+     # State to store the original uploaded image (needed for final mask resizing)
+     original_image_state = gr.State(None)
+     # State to store the preprocessed 1024x1024 image data (used for drawing points and predictor input)
+     preprocessed_image_state = gr.State(None)
+
+
+     with gr.Row():
+         # --- Left Column (Controls & Preprocessed Image Interaction) ---
+         with gr.Column(scale=1, elem_classes="control-col"):
+             gr.Markdown("## 1. Upload & Controls")
+             # Keep upload separate and smaller
+             upload_image = gr.Image(
+                 type="numpy", label="Upload Coronary X-ray Image",
+                 height=150, elem_id="upload_container"
+             )
+             gr.Markdown("## 2. Add Points on Preprocessed Image")
+             # Interactive Preprocessed Image Display
+             preprocessed_image_display = gr.Image(
+                 type="numpy", label="Click on Image to Add Points",
+                 interactive=True,  # Make this interactive
+                 height=450, elem_id="preprocessed_image_container"
+             )
+             # Add Point Counter Display
+             point_counter_display = gr.Markdown(get_point_counts_text([]))
+
+             model_selector = gr.Dropdown(
+                 choices=list(models_available.keys()),
+                 label="Select SAM2 Model Variant",
+                 value=list(models_available.keys())[-1] if models_available else None
+             )
+             prompt_type = gr.Radio(
+                 ["Positive", "Negative"], label="Point Prompt Type", value="Positive"
+             )
+             with gr.Row():
+                 clear_button = gr.Button("Clear Points")
+                 undo_button = gr.Button("Undo Last Point")
+             run_button = gr.Button("Run Segmentation", variant="primary")
+             clear_all_button = gr.Button("Clear All")  # Added Clear All
+
+         # --- Right Column (Output Mask) ---
+         with gr.Column(scale=1, elem_classes="output-col"):
+             gr.Markdown("## 3. Predicted Mask")
+             final_mask_display = gr.Image(
+                 type="numpy", label="Predicted Binary Mask (White = Artery)",
+                 interactive=False, height=450, elem_id="mask_display_container",
+                 format="png"  # Specify PNG format for download
+             )
+
+
+     # --- Define Interactions ---
+
+     # 1. Upload triggers preprocessing and display
+     upload_image.upload(
+         fn=process_upload,
+         inputs=[upload_image],
+         outputs=[
+             preprocessed_image_display,  # Update interactive display
+             preprocessed_image_state,  # Update state
+             original_image_state,  # Update state
+             points_state,  # Clear points
+             final_mask_display,  # Clear mask display
+             point_counter_display  # Clear point counts
+         ]
+     )
+
+     # 2. Clicking on preprocessed image adds points
+     preprocessed_image_display.select(
+         fn=add_point,
+         inputs=[preprocessed_image_state, points_state, prompt_type],
+         outputs=[
+             preprocessed_image_display,  # Update display with points
+             points_state,  # Update points state
+             point_counter_display  # Update point counts
+         ]
+     )
+
+     # 3. Clear points button resets points and preprocessed display
+     clear_button.click(
+         fn=clear_points_and_display,
+         inputs=[preprocessed_image_state],  # Needs the clean preprocessed image
+         outputs=[
+             preprocessed_image_display,  # Reset display
+             points_state,  # Clear points
+             final_mask_display,  # Clear mask
+             point_counter_display  # Reset point counts
+         ]
+     )
+
+     # 4. Undo button removes last point and updates preprocessed display
+     undo_button.click(
+         fn=undo_last_point,
+         inputs=[preprocessed_image_state, points_state],  # Needs current preprocessed image
+         outputs=[
+             preprocessed_image_display,  # Update display
+             points_state,  # Update points state
+             point_counter_display  # Update point counts
+         ]
+     )
+
+     # 5. Run segmentation (Outputs don't change point counts)
+     run_button.click(
+         fn=run_segmentation,
+         inputs=[
+             preprocessed_image_state,  # Use preprocessed image data
+             original_image_state,  # Needed for final resize dim
+             model_selector,
+             points_state  # Points are relative to preprocessed
+         ],
+         outputs=[
+             final_mask_display,  # Show the final mask
+             points_state  # Pass points state (might be needed if run modifies it - currently doesn't)
+         ]
+     )
+
+     # 6. Clear All button
+     clear_all_button.click(
+         fn=clear_all_outputs,
+         inputs=[],
+         outputs=[
+             upload_image,
+             preprocessed_image_display,
+             preprocessed_image_state,
+             points_state,
+             final_mask_display,
+             point_counter_display  # Reset point counts
+         ]
+     )
+
+
+ # --- Launch the App ---
+ if __name__ == "__main__":
+     print("Launching Gradio App...")
+     demo.launch()
packages.txt ADDED
@@ -0,0 +1,3 @@
+ libgl1
+ libglib2.0-0
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ torchvision
+ numpy
+ opencv-python-headless
+ Pillow
+ huggingface_hub
+ git+https://github.com/facebookresearch/segment-anything-2.git