import os
import gradio as gr
import torch
import numpy as np
import cv2
import time
import functools
from PIL import Image
from huggingface_hub import hf_hub_download

# --- Configuration ---
# Hugging Face model repositories and filenames
HF_MODEL_CONFIG = {
    "SAM2 Hiera Tiny": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_t.pt"
    },
    "SAM2 Hiera Small": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_s.pt"
    },
    "SAM2 Hiera Base Plus": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_b+.pt"
    },
    "SAM2 Hiera Large": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_l.pt"
    }
}

# Download and cache models from Hugging Face
print("Checking and downloading models from Hugging Face...")
models_available = {}
for name, config in HF_MODEL_CONFIG.items():
    try:
        print(f"Downloading {name} from {config['repo_id']}...")
        model_path = hf_hub_download(
            repo_id=config["repo_id"],
            filename=config["filename"],
            cache_dir="./hf_cache"
        )
        models_available[name] = model_path
        print(f"✓ {name} downloaded successfully to {model_path}")
    except Exception as e:
        print(f"✗ Warning: Failed to download {name} from {config['repo_id']}: {e}")
        print(f"  {name} will not be available in the dropdown.")

if not models_available:
    print("Error: No valid models could be downloaded. Please check your internet connection and HF repository access.")
    # exit()  # Or handle gracefully

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Try importing SAM2 modules
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except ImportError:
    print("Error: SAM2 modules not found. Make sure 'sam2' directory is in your Python path or installed.")
    exit()


# --- Preprocessing Functions ---

def normalize_xray_image(image, kernel_size=(51, 51), sigma=0):
    """Normalize X-ray image by applying Gaussian blur and intensity normalization."""
    if image is None:
        return None
    is_color = len(image.shape) == 3
    if is_color:
        gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray_image = image.copy()
    gray_image = gray_image.astype(float)
    blurred = gray_image.copy()
    for _ in range(5):  # Reduced iterations
        blurred = cv2.GaussianBlur(blurred, kernel_size, sigma)
    mean_intensity = np.mean(blurred)
    factor_image = mean_intensity / (blurred + 1e-10)
    if is_color:
        normalized_image = image.copy().astype(float)
        for i in range(3):
            normalized_image[:, :, i] = normalized_image[:, :, i] * factor_image
    else:
        normalized_image = gray_image * factor_image
    return np.clip(normalized_image, 0, 255).astype(np.uint8)


def apply_clahe(image_uint8):
    """Apply CLAHE for better vessel contrast. Expects uint8 input."""
    if image_uint8 is None:
        return None
    is_color = len(image_uint8.shape) == 3
    # --- ADJUST CLAHE STRENGTH HERE ---
    # Lower clipLimit reduces the contrast enhancement effect.
    # Original was 2.0. Try values like 1.5, 1.0, or even disable by setting it very low.
    clahe_clip_limit = 2.0
    clahe_tile_grid_size = (8, 8)
    print(f"  Applying CLAHE with clipLimit={clahe_clip_limit}, tileGridSize={clahe_tile_grid_size}")
    # ---------------------------------
    clahe = cv2.createCLAHE(clipLimit=clahe_clip_limit, tileGridSize=clahe_tile_grid_size)
    if is_color:
        lab = cv2.cvtColor(image_uint8, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        l_clahe = clahe.apply(l)
        lab_clahe = cv2.merge((l_clahe, a, b))
        clahe_image_uint8 = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
    else:
        clahe_image_uint8 = clahe.apply(image_uint8)
    # Return uint8 [0, 255] suitable for predictor.set_image
    return clahe_image_uint8


def preprocess_image_for_sam2(image_rgb_numpy):
    """Combined preprocessing: normalization + CLAHE for SAM2 input."""
    if image_rgb_numpy is None:
        print("Preprocessing: Input image is None.")
        return None
    start_time = time.time()
    print("Preprocessing Step 1: Normalizing X-ray image...")
    if image_rgb_numpy.dtype != np.uint8:
        image_rgb_numpy = np.clip(image_rgb_numpy, 0, 255).astype(np.uint8)
    if len(image_rgb_numpy.shape) == 2:
        image_rgb_numpy = cv2.cvtColor(image_rgb_numpy, cv2.COLOR_GRAY2RGB)
    normalized_uint8 = normalize_xray_image(image_rgb_numpy)
    if normalized_uint8 is None:
        print("Preprocessing failed at normalization step.")
        return None
    print(f"Normalization done in {time.time() - start_time:.2f}s")
    start_time_clahe = time.time()
    print("Preprocessing Step 2: Applying CLAHE...")
    preprocessed_uint8 = apply_clahe(normalized_uint8)  # CLAHE applied here
    if preprocessed_uint8 is None:
        print("Preprocessing failed at CLAHE step.")
        return None
    print(f"CLAHE done in {time.time() - start_time_clahe:.2f}s")
    print(f"Total preprocessing time: {time.time() - start_time:.2f}s")
    return preprocessed_uint8  # Return the image after all steps


# --- Model Loading ---

@functools.lru_cache(maxsize=4)  # Cache up to 4 models (one for each variant)
def load_model(model_name):
    """Loads the specified SAM2 model and creates a predictor."""
    print(f"\nAttempting to load model: {model_name}")
    if model_name not in models_available:  # Check against available models
        print(f"Error: Model name '{model_name}' not found or checkpoint missing.")
        return None
    checkpoint_path = models_available[model_name]  # Get path from the available-models dict
    try:
        print(f"  Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        if 'model_cfg' not in checkpoint:
            print(f"Error: 'model_cfg' key not found in checkpoint {checkpoint_path}.")
            return None
        model_cfg_name = checkpoint['model_cfg']
        print(f"  Using model config from checkpoint: {model_cfg_name}")
        sam2_model = build_sam2(model_cfg_name, checkpoint_path=None, device=DEVICE)
        if 'model_state_dict' not in checkpoint:
            print(f"Error: 'model_state_dict' not found in checkpoint {checkpoint_path}.")
            return None
        state_dict = checkpoint['model_state_dict']
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k  # Strip 'module.' prefix from DataParallel checkpoints
            new_state_dict[name] = v
        sam2_model.load_state_dict(new_state_dict)
        print("  Successfully loaded fine-tuned model state_dict.")
        sam2_model.to(DEVICE)
        sam2_model.eval()
        predictor = SAM2ImagePredictor(sam2_model)
        print(f"Model '{model_name}' loaded successfully on {DEVICE}.")
        return predictor
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        import traceback
        traceback.print_exc()
        return None
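
# Optional illustration (not part of the original flow and not called anywhere below):
# because load_model is wrapped in functools.lru_cache, the first request for each model
# variant pays the full checkpoint-loading cost. A hypothetical helper like
# prewarm_models() could be run once at startup to hide that latency.
def prewarm_models(model_names=None):
    """Illustrative sketch: eagerly populate the load_model cache at startup."""
    for name in (model_names or list(models_available.keys())):
        if load_model(name) is None:
            print(f"Prewarm: model '{name}' could not be loaded.")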

# --- Utility Functions ---

def resize_image_fixed(image, target_size=1024):
    """Resizes image to a fixed square size (1024x1024 by default)."""
    if image is None:
        return None
    return cv2.resize(image, (target_size, target_size), interpolation=cv2.INTER_LINEAR)


def draw_points_on_image(image, points_state):
    """Draws points (green positive, red negative) on a copy of the image."""
    if image is None:
        return image  # Return original if no image
    draw_image = image.copy()
    if not points_state:
        return draw_image  # Return copy if no points
    # Make points slightly larger and add a black border
    base_radius = max(4, int(min(image.shape[:2]) * 0.006))  # Slightly larger base radius
    border_thickness = 1  # Thickness of the black border
    radius_with_border = base_radius + border_thickness
    thickness = -1  # Filled circle
    for x, y, label in points_state:
        color = (0, 255, 0) if label == 1 else (255, 0, 0)
        center = (int(x), int(y))
        # Draw black border circle first
        cv2.circle(draw_image, center, radius_with_border, (0, 0, 0), thickness)
        # Draw colored circle on top
        cv2.circle(draw_image, center, base_radius, color, thickness)
    return draw_image


# --- Gradio UI Interaction Functions ---

def get_point_counts_text(points_state):
    """Helper function to generate the point count markdown string."""
    pos_count = sum(1 for _, _, label in points_state if label == 1)
    neg_count = sum(1 for _, _, label in points_state if label == 0)
    return f"**Points Added:** {pos_count} Positive, {neg_count} Negative"


def add_point(preprocessed_image, points_state, point_type, evt: gr.SelectData):
    """Callback function when the user clicks on the preprocessed image."""
    if preprocessed_image is None:
        gr.Warning("Please upload and preprocess an image first.")
        # Return original image, points state, and existing counts text
        return preprocessed_image, points_state, get_point_counts_text(points_state)
    x, y = evt.index[0], evt.index[1]
    label = 1 if point_type == "Positive" else 0
    # Store coordinates relative to the preprocessed image (1024x1024)
    points_state.append([x, y, label])
    print(f"Added point: ({x}, {y}), Type: {'Positive' if label == 1 else 'Negative'}, Total Points: {len(points_state)}")
    image_with_points = draw_points_on_image(preprocessed_image, points_state)
    # Return updated image, points state, and updated counts text
    return image_with_points, points_state, get_point_counts_text(points_state)


def undo_last_point(preprocessed_image, points_state):
    """Removes the last added point and updates the preprocessed display image."""
    if preprocessed_image is None:  # Handle case where image is cleared
        # Return None image, points state, and counts text
        return None, points_state, get_point_counts_text(points_state)
    if not points_state:
        print("No points to undo.")
        # Return the current preprocessed image without changes if no points
        return preprocessed_image, points_state, get_point_counts_text(points_state)
    removed_point = points_state.pop()
    print(f"Removed point: {removed_point}, Remaining Points: {len(points_state)}")
    image_with_points = draw_points_on_image(preprocessed_image, points_state)
    # Return updated image, points state, and updated counts text
    return image_with_points, points_state, get_point_counts_text(points_state)


def clear_points_and_display(preprocessed_image_state):
    """Clears points and resets the preprocessed display image."""
    print("Clearing points and resetting preprocessed display.")
    points_state = []  # Clear points
    # Return the stored preprocessed image without points, cleared points state, cleared mask, cleared counts text
    return preprocessed_image_state, points_state, None, get_point_counts_text(points_state)
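
# Illustrative sketch (hypothetical helper, not wired into the UI): blend a binary mask
# over an RGB image for quick visual inspection of a prediction. Assumes the mask and the
# image share the same height and width.
def overlay_mask_on_image(image_rgb, binary_mask, color=(255, 0, 0), alpha=0.4):
    """Return a copy of image_rgb with the mask region tinted in the given color."""
    tinted = image_rgb.copy()
    tinted[binary_mask > 0] = color
    # Weighted blend: alpha controls how strongly the tint shows through
    return cv2.addWeighted(tinted, alpha, image_rgb, 1 - alpha, 0)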

def run_segmentation(preprocessed_image_state, original_image_state, model_name, points_state):
    """Runs SAM2 segmentation using points on the preprocessed image."""
    start_total_time = time.time()
    # Initialize return values
    output_mask_display = None
    if preprocessed_image_state is None or original_image_state is None:
        gr.Warning("Please upload an image first.")
        return output_mask_display, points_state

    print("\n--- Running Segmentation ---")
    print(f"  Model Selected: {model_name}")
    print(f"  Number of points: {len(points_state)}")

    # --- 1. Load Model ---
    predictor = load_model(model_name)
    if predictor is None:
        gr.Error(f"Failed to load model '{model_name}'. Check logs and paths.")
        return output_mask_display, points_state

    # --- 2. Use Preprocessed Image ---
    # The image is already preprocessed and resized to 1024x1024
    image_for_predictor = preprocessed_image_state
    original_h, original_w = original_image_state.shape[:2]  # Get original dims for final resize
    print("  Using preprocessed image (1024x1024) for predictor.")
    print(f"  Original image size for final mask resize: {original_w}x{original_h}")

    print("  Setting preprocessed image in predictor...")
    start_set_image = time.time()
    # Feed the preprocessed image (which is already 1024x1024 uint8) to SAM
    predictor.set_image(image_for_predictor)
    print(f"  predictor.set_image took {time.time() - start_set_image:.2f}s")

    # --- 3. Prepare Prompts (No Scaling Needed) ---
    if not points_state:
        # Use center point if no points provided
        center_x, center_y = 512, 512
        point_coords = np.array([[[center_x, center_y]]])
        point_labels = np.array([1])
        print("  No points provided. Using center point (512, 512).")
    else:
        # Points are already relative to the 1024x1024 preprocessed image
        point_coords_list = [[x, y] for x, y, label in points_state]
        labels_list = [label for x, y, label in points_state]
        point_coords = np.array([point_coords_list])
        point_labels = np.array(labels_list)
        print(f"  Using {len(points_state)} provided points (coords relative to 1024x1024).")

    point_coords_torch = torch.tensor(point_coords, dtype=torch.float32).to(DEVICE)
    point_labels_torch = torch.tensor(point_labels, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # Add batch dim

    # --- 4. Run Model Inference ---
    print("  Running model inference...")
    start_inference_time = time.time()
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(point_coords_torch, point_labels_torch), boxes=None, masks=None
        )
        if predictor._features is None:
            gr.Error("Image features not computed. Predictor might not have been set correctly.")
            return output_mask_display, points_state
        # Ensure features are accessed correctly
        image_embed = predictor._features["image_embed"][-1].unsqueeze(0)
        image_pe = predictor.model.sam_prompt_encoder.get_dense_pe()
        # Handle potential missing high_res_features key gracefully
        high_res_features = None
        if "high_res_feats" in predictor._features and predictor._features["high_res_feats"]:
            try:
                high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            except IndexError:
                print("Warning: Index error accessing high_res_feats. Proceeding without them.")
            except Exception as e:
                print(f"Warning: Error processing high_res_features: {e}. Proceeding without them.")

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=image_embed,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,  # repeat_image should be False for single-image prediction
            high_res_features=high_res_features,  # Pass None if not available
        )
        # Postprocess masks to 1024x1024 (predictor._orig_hw should be (1024, 1024))
        prd_masks_1024 = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
        # Select the best mask based on predicted score
        best_mask_idx = torch.argmax(prd_scores[0]).item()
        # Apply sigmoid and thresholding
        best_mask_1024_prob = torch.sigmoid(prd_masks_1024[:, best_mask_idx])
        binary_mask_1024 = (best_mask_1024_prob > 0.5).cpu().numpy().squeeze()  # Squeeze to get (H, W)

    print(f"  Model inference took {time.time() - start_inference_time:.2f}s")

    # --- 5. Resize Mask to Original Dimensions ---
    print("  Resizing mask to original dimensions...")
    final_mask_resized = cv2.resize(
        binary_mask_1024.astype(np.uint8),
        (original_w, original_h),
        interpolation=cv2.INTER_NEAREST
    )

    # --- 6. Format Mask for Display ---
    # Mask for display (RGB)
    output_mask_display = (final_mask_resized * 255).astype(np.uint8)
    if len(output_mask_display.shape) == 2:  # Ensure RGB for display consistency
        output_mask_display = cv2.cvtColor(output_mask_display, cv2.COLOR_GRAY2RGB)

    total_time = time.time() - start_total_time
    print(f"--- Segmentation Complete (Total time: {total_time:.2f}s) ---")
    # Return: mask for display, points state (unchanged; this function doesn't modify points)
    return output_mask_display, points_state
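
# Illustrative sketch (hypothetical helper, not used by the app): click points are stored
# in 1024x1024 preprocessed-image coordinates, and only the mask is resized back to the
# original resolution inside run_segmentation. If point coordinates were ever needed in
# original-image space, a mapping like this would mirror that final resize.
def scale_points_to_original(points_state, original_h, original_w, target_size=1024):
    """Map [x, y, label] points from target_size x target_size space to original dims."""
    scale_x = original_w / float(target_size)
    scale_y = original_h / float(target_size)
    return [[x * scale_x, y * scale_y, label] for x, y, label in points_state]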

def process_upload(uploaded_image):
    """Handles image upload: preprocesses, resizes, stores states."""
    if uploaded_image is None:
        # Clear everything including point counts
        return None, None, None, [], None, get_point_counts_text([])

    print("Image uploaded. Processing...")
    # 1. Store original image
    original_image = uploaded_image.copy()

    # 2. Resize to 1024x1024 for preprocessing
    image_resized_1024 = resize_image_fixed(original_image, 1024)
    if image_resized_1024 is None:
        gr.Error("Failed to resize image.")
        # Return the full set of outputs (including cleared point counts) on failure
        return None, None, None, [], None, get_point_counts_text([])

    # 3. Preprocess the 1024x1024 image
    preprocessed_1024 = preprocess_image_for_sam2(image_resized_1024)
    if preprocessed_1024 is None:
        gr.Error("Image preprocessing failed.")
        return None, None, None, [], None, get_point_counts_text([])

    # Ensure preprocessed image is RGB for display
    if len(preprocessed_1024.shape) == 2:
        preprocessed_1024_display = cv2.cvtColor(preprocessed_1024, cv2.COLOR_GRAY2RGB)
    else:
        preprocessed_1024_display = preprocessed_1024.copy()

    print("Image processed successfully.")
    points_state = []  # Clear points on new upload

    # Return:
    # 1. Preprocessed image for display (interactive)
    # 2. Preprocessed image for state
    # 3. Original image for state
    # 4. Cleared points state
    # 5. Cleared mask display
    # 6. Cleared point counts text
    return preprocessed_1024_display, preprocessed_1024, original_image, points_state, None, get_point_counts_text(points_state)
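
# Illustrative sketch (hypothetical helper, not called by the app): Dice coefficient
# between a predicted binary mask and a ground-truth mask, handy for offline sanity
# checks of downloaded masks against annotations.
def dice_coefficient(pred_mask, gt_mask, eps=1e-7):
    """Dice = 2*|P∩G| / (|P| + |G|); both inputs are treated as binary masks."""
    pred = (np.asarray(pred_mask) > 0).astype(np.float32)
    gt = (np.asarray(gt_mask) > 0).astype(np.float32)
    intersection = float((pred * gt).sum())
    return (2.0 * intersection + eps) / (float(pred.sum()) + float(gt.sum()) + eps)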

def clear_all_outputs():
    """Clears all input/output fields and states."""
    print("Clearing all inputs and outputs.")
    points_state = []  # Clear points
    # Clear everything including point counts
    return None, None, None, points_state, None, get_point_counts_text(points_state)


# --- Build Gradio Interface ---
css = """
#mask_display_container .gradio-image { height: 450px !important; object-fit: contain; }
#preprocessed_image_container .gradio-image { height: 450px !important; object-fit: contain; cursor: crosshair !important; }
#upload_container .gradio-image { height: 150px !important; object-fit: contain; } /* Smaller upload preview */
.output-col img { max-height: 450px; object-fit: contain; }
.control-col { min-width: 500px; } /* Wider control column */
.output-col { min-width: 500px; }
"""

with gr.Blocks(css=css, title="Coronary Artery Segmentation (Fine-tuned SAM2)") as demo:
    gr.Markdown("# Coronary Artery Segmentation using Fine-tuned SAM2")
    gr.Markdown(
        "**Let's find those arteries!**\n\n"
        "1. Upload your Coronary X-ray Image.\n"
        "2. The preprocessed image appears on the left. Time to guide the AI! Click directly on the image to add **Positive** (artery) or **Negative** (background) points.\n"
        "3. Choose your fine-tuned SAM2 model.\n"
        "4. Hit 'Run Segmentation' and watch the magic happen!\n"
        "5. Download your predicted mask (the white area) using the download button on the mask image."
    )

    # --- States ---
    points_state = gr.State([])
    # State to store the original uploaded image (needed for final mask resizing)
    original_image_state = gr.State(None)
    # State to store the preprocessed 1024x1024 image data (used for drawing points and predictor input)
    preprocessed_image_state = gr.State(None)

    with gr.Row():
        # --- Left Column (Controls & Preprocessed Image Interaction) ---
        with gr.Column(scale=1, elem_classes="control-col"):
            gr.Markdown("## 1. Upload & Controls")
            # Keep upload separate and smaller
            upload_image = gr.Image(
                type="numpy",
                label="Upload Coronary X-ray Image",
                height=150,
                elem_id="upload_container"
            )
            gr.Markdown("## 2. Add Points on Preprocessed Image")
            # Interactive preprocessed image display
            preprocessed_image_display = gr.Image(
                type="numpy",
                label="Click on Image to Add Points",
                interactive=True,  # Make this interactive
                height=450,
                elem_id="preprocessed_image_container"
            )
            # Point counter display
            point_counter_display = gr.Markdown(get_point_counts_text([]))
            model_selector = gr.Dropdown(
                choices=list(models_available.keys()),
                label="Select SAM2 Model Variant",
                value=list(models_available.keys())[-1] if models_available else None
            )
            prompt_type = gr.Radio(
                ["Positive", "Negative"],
                label="Point Prompt Type",
                value="Positive"
            )
            with gr.Row():
                clear_button = gr.Button("Clear Points")
                undo_button = gr.Button("Undo Last Point")
            run_button = gr.Button("Run Segmentation", variant="primary")
            clear_all_button = gr.Button("Clear All")

        # --- Right Column (Output Mask) ---
        with gr.Column(scale=1, elem_classes="output-col"):
            gr.Markdown("## 3. Predicted Mask")
            final_mask_display = gr.Image(
                type="numpy",
                label="Predicted Binary Mask (White = Artery)",
                interactive=False,
                height=450,
                elem_id="mask_display_container",
                format="png"  # Specify PNG format for download
            )
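
    # Note on the wiring below (descriptive comment only, no new behavior): the three
    # gr.State objects hold per-session data on the server. points_state stores
    # [x, y, label] triples in 1024x1024 preprocessed-image coordinates, which is why
    # run_segmentation can pass them to SAM2 without any rescaling.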
    # --- Define Interactions ---

    # 1. Upload triggers preprocessing and display
    upload_image.upload(
        fn=process_upload,
        inputs=[upload_image],
        outputs=[
            preprocessed_image_display,  # Update interactive display
            preprocessed_image_state,    # Update state
            original_image_state,        # Update state
            points_state,                # Clear points
            final_mask_display,          # Clear mask display
            point_counter_display        # Clear point counts
        ]
    )

    # 2. Clicking on the preprocessed image adds points
    preprocessed_image_display.select(
        fn=add_point,
        inputs=[preprocessed_image_state, points_state, prompt_type],
        outputs=[
            preprocessed_image_display,  # Update display with points
            points_state,                # Update points state
            point_counter_display        # Update point counts
        ]
    )

    # 3. Clear Points button resets points and the preprocessed display
    clear_button.click(
        fn=clear_points_and_display,
        inputs=[preprocessed_image_state],  # Needs the clean preprocessed image
        outputs=[
            preprocessed_image_display,  # Reset display
            points_state,                # Clear points
            final_mask_display,          # Clear mask
            point_counter_display        # Reset point counts
        ]
    )

    # 4. Undo button removes the last point and updates the preprocessed display
    undo_button.click(
        fn=undo_last_point,
        inputs=[preprocessed_image_state, points_state],  # Needs current preprocessed image
        outputs=[
            preprocessed_image_display,  # Update display
            points_state,                # Update points state
            point_counter_display        # Update point counts
        ]
    )

    # 5. Run segmentation (outputs don't change point counts)
    run_button.click(
        fn=run_segmentation,
        inputs=[
            preprocessed_image_state,  # Use preprocessed image data
            original_image_state,      # Needed for final resize dimensions
            model_selector,
            points_state               # Points are relative to the preprocessed image
        ],
        outputs=[
            final_mask_display,  # Show the final mask
            points_state         # Points state is currently returned unchanged
        ]
    )

    # 6. Clear All button
    clear_all_button.click(
        fn=clear_all_outputs,
        inputs=[],
        outputs=[
            upload_image,
            preprocessed_image_display,
            preprocessed_image_state,
            points_state,
            final_mask_display,
            point_counter_display  # Reset point counts
        ]
    )

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    demo.launch()
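
# Deployment note (assumption, not from the original script): when hosting remotely or in
# a container, demo.launch(server_name="0.0.0.0", server_port=7860) exposes the app on the
# network; demo.launch(share=True) creates a temporary public Gradio link for local runs.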