# app.py
"""FastAPI service that classifies an uploaded medical image series
(DICOM/PNG/JPEG) and generates a report with the combined multimodal model."""
import gc
import io
import os

import numpy as np
import pydicom
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from model import CombinedModel, ImageToTextProjector

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root(request: Request):
    return {"message": "Welcome to Phronesis"}
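
# Run inference on the GPU when one is available; fall back to CPU otherwise.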
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def dicom_to_png(dicom_data):
    """Convert raw DICOM bytes to an 8-bit grayscale PIL image."""
    try:
        dicom_file = pydicom.dcmread(io.BytesIO(dicom_data))
        if not hasattr(dicom_file, 'PixelData'):
            raise HTTPException(status_code=400, detail="No pixel data in DICOM file.")
        pixel_array = dicom_file.pixel_array.astype(np.float32)
        # Min-max normalize to [0, 255]; guard against constant images.
        pixel_range = np.ptp(pixel_array) or 1.0
        pixel_array = ((pixel_array - pixel_array.min()) / pixel_range) * 255.0
        pixel_array = pixel_array.astype(np.uint8)
        return Image.fromarray(pixel_array).convert("L")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error converting DICOM to PNG: {e}")

# Set up secure model initialization
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    raise ValueError("Missing Hugging Face token in environment variables.")

try:
    report_generator_tokenizer = AutoTokenizer.from_pretrained(
        "KYAGABA/combined-multimodal-model",
        token=HF_TOKEN,
    )
    # 3D ResNet-18 video backbone pretrained on Kinetics-400; its
    # classification head is replaced with a 512-dimensional projection.
    video_model = models.video.r3d_18(weights="KINETICS400_V1")
    video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
    report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
    projector = ImageToTextProjector(512, report_generator.config.d_model)
    num_classes = 4
    combined_model = CombinedModel(video_model, report_generator, num_classes, projector, report_generator_tokenizer)
    model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
    state_dict = torch.load(model_file, map_location=device)
    combined_model.load_state_dict(state_dict)
    combined_model.to(device)
    combined_model.eval()
except Exception as e:
    raise SystemExit(f"Error loading models: {e}")
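
# Preprocessing uses the Kinetics-400 statistics that r3d_18 was pretrained with.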
image_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])
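
# Maps the classifier's argmax index to a human-readable label.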
class_names = ["acute", "normal", "chronic", "lacunar"]

@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    print(f"Received {len(files)} files")
    n_frames = 16
    images = []
    for file in files:
        ext = file.filename.split('.')[-1].lower()
        try:
            if ext in ['dcm', 'ima']:
                dicom_img = dicom_to_png(await file.read())
                images.append(dicom_img.convert("RGB"))
            elif ext in ['png', 'jpeg', 'jpg']:
                img = Image.open(io.BytesIO(await file.read())).convert("RGB")
                images.append(img)
            else:
                raise HTTPException(status_code=400, detail="Unsupported file type.")
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
    if not images:
        return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
    # Evenly subsample when there are too many frames; pad with the last frame otherwise.
    if len(images) >= n_frames:
        images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
    else:
        images_sampled = images + [images[-1]] * (n_frames - len(images))
    # Stack to (T, C, H, W), then reorder to (1, C, T, H, W) as r3d_18 expects.
    image_tensors = [image_transform(img) for img in images_sampled]
    images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)
    with torch.no_grad():
        class_outputs, generated_report, _ = combined_model(images_tensor)
        predicted_class = torch.argmax(class_outputs, dim=1).item()
        predicted_class_name = class_names[predicted_class]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return {
        "predicted_class": predicted_class_name,
        "generated_report": generated_report[0] if generated_report else "No report generated.",
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
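
# Example request against a local run (port 7860 by default; "/predict" is the
# path assumed for the handler above):
#
#   curl -X POST http://localhost:7860/predict \
#        -F "files=@slice_001.dcm" -F "files=@slice_002.dcm"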