# app.py
import os

# Point the Hugging Face cache at a writable path *before* importing
# huggingface_hub/transformers, which resolve HF_HOME when first imported.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import gc
import io
from typing import List

import gradio as gr
import numpy as np
import pydicom
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from model import CombinedModel, ImageToTextProjector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_TOKEN = os.getenv("HF_TOKEN")

# Model loading
tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)

video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
projector = ImageToTextProjector(512, report_generator.config.d_model)

num_classes = 4
class_names = ["acute", "normal", "chronic", "lacunar"]

combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)

model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
state_dict = torch.load(model_file, map_location=device)
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()
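
# Note: "pytorch_model.bin" is assumed to contain the state dict of the full
# CombinedModel (video backbone + projector + report decoder); load_state_dict
# would raise on missing or unexpected keys if it covered only a sub-module.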

image_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    # Kinetics-400 channel statistics, matching the r3d_18 pretraining data
    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])


def dicom_to_image(file_bytes):
    """Convert DICOM file bytes to an 8-bit RGB PIL Image."""
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)
    # Min-max scale to 0-255. np.ptp (the function, not the ndarray method
    # removed in NumPy 2.0) with a small floor guards constant-valued slices.
    pixel_range = max(float(np.ptp(pixel_array)), 1e-6)
    pixel_array = (pixel_array - pixel_array.min()) / pixel_range * 255.0
    return Image.fromarray(pixel_array.astype(np.uint8)).convert("RGB")
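
# Minimal usage sketch (hypothetical file path):
#   with open("study/slice_001.dcm", "rb") as f:
#       img = dicom_to_image(f.read())
#   img.size  # (columns, rows) of the scan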


def process_images(file_data_list):
    """Core image-processing logic shared by the Gradio and FastAPI entry points."""
    if not file_data_list:
        return "No images uploaded.", ""

    processed_imgs = []
    for file_data in file_data_list:
        filename = file_data.get("filename", "").lower()
        file_content = file_data.get("content")
        try:
            if filename.endswith((".dcm", ".ima")):
                img = dicom_to_image(file_content)
            else:
                img = Image.open(io.BytesIO(file_content)).convert("RGB")
            processed_imgs.append(img)
        except Exception as e:
            print(f"Error processing file {filename}: {e}")
            continue

    if not processed_imgs:
        return "No valid images processed.", ""

    # Sample evenly spaced frames for the video model, padding by repeating
    # the last slice when fewer than n_frames are available
    n_frames = 16
    if len(processed_imgs) >= n_frames:
        images_sampled = [
            processed_imgs[i]
            for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
        ]
    else:
        images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))

    # Stack to (T, C, H, W), then permute/unsqueeze to the (1, C, T, H, W)
    # layout that r3d_18 expects
    tensor_imgs = [image_transform(img) for img in images_sampled]
    input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)

    # Model inference
    with torch.no_grad():
        class_logits, report, _ = combined_model(input_tensor)
        class_pred = torch.argmax(class_logits, dim=1).item()
        class_name = class_names[class_pred]

    # Release memory between requests
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return class_name, report[0] if report else "No report generated."
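
# Assumed CombinedModel contract (model.py is not shown here): the forward pass
# returns (class_logits of shape [1, num_classes], a list of decoded report
# strings, and an auxiliary value this app ignores).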


def predict_gradio(files):
    """Gradio interface wrapper."""
    if not files:
        return "No images uploaded.", ""

    file_data_list = []
    for file_obj in files:
        try:
            if hasattr(file_obj, "read"):
                file_content = file_obj.read()
                filename = getattr(file_obj, "name", str(file_obj))
            else:
                # Newer Gradio versions pass temp-file paths instead of file objects
                filename = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
                with open(filename, "rb") as f:
                    file_content = f.read()
            file_data_list.append({"filename": filename, "content": file_content})
        except Exception as e:
            print(f"Error reading file: {e}")
            continue

    return process_images(file_data_list)


# Create FastAPI app
app = FastAPI(
    title="Phronesis ML API",
    description="Medical Image Analysis API with Gradio Interface",
    version="1.0.0",
)

# Add CORS middleware (wildcard origins with credentials is fine for a demo,
# but browsers reject credentialed requests to "*"; pin origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "message": "Phronesis ML API",
        "status": "running",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "gradio": "/gradio",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": True,
        "device": str(device),
    }


@app.post("/predict")
async def predict_api(files: List[UploadFile] = File(...)):
    """
    API endpoint for medical image prediction.

    Args:
        files: List of uploaded image files (DICOM, JPG, PNG, etc.)

    Returns:
        JSON response with the predicted class and generated report.
    """
    try:
        if not files:
            raise HTTPException(status_code=400, detail="No files uploaded")

        # Read the uploaded files into memory
        file_data_list = []
        for file in files:
            try:
                content = await file.read()
                file_data_list.append({
                    "filename": file.filename or "unknown",
                    "content": content,
                })
            except Exception as e:
                print(f"Error reading uploaded file {file.filename}: {e}")
                continue

        if not file_data_list:
            raise HTTPException(status_code=400, detail="No valid files processed")

        # Run classification and report generation
        predicted_class, generated_report = process_images(file_data_list)

        return JSONResponse(content={
            "status": "success",
            "data": {
                "predicted_class": predicted_class,
                "generated_report": generated_report,
                "processed_files": len(file_data_list),
            },
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")


@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Global exception handler."""
    return JSONResponse(
        status_code=500,
        content={
            "status": "error",
            "message": "Internal server error",
            "detail": str(exc),
        },
    )


# Create Gradio interface
demo = gr.Interface(
    fn=predict_gradio,
    inputs=gr.File(
        file_count="multiple",
        file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png", ".bmp"],
        label="Upload Medical Images",
    ),
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Generated Report", lines=5),
    ],
    title="🩺 Phronesis Medical Report Generator",
    description="""
    Upload CT scan images to generate a medical report and classification.

    **Supported formats:** DICOM (.dcm, .ima), JPEG, PNG, BMP

    **API Endpoint:** `/predict` (POST)
    """,
    examples=[],
    allow_flagging="never",
)

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/gradio")
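
# Example call against the REST endpoint once the server is running
# (hypothetical host and filenames):
#   curl -X POST http://localhost:7860/predict \
#        -F "files=@slice_001.dcm" -F "files=@slice_002.dcm"
# which should return JSON shaped like:
#   {"status": "success", "data": {"predicted_class": "acute",
#    "generated_report": "...", "processed_files": 2}}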

if __name__ == "__main__":
    import uvicorn

    # Serve the combined app on the port Hugging Face Spaces expects: the
    # Gradio UI lives at /gradio and the REST endpoints (/predict, /health)
    # stay reachable. Calling demo.launch() here would start a standalone
    # Gradio server and bypass the FastAPI routes.
    uvicorn.run(app, host="0.0.0.0", port=7860)