Upload 5 files

- Dockerfile +16 -0
- README.md +7 -6
- app.py +129 -0
- model.py +56 -0
- requirements.txt +16 -0
Dockerfile
ADDED
@@ -0,0 +1,16 @@
+# Use an official Python runtime as a base image
+FROM python:3.9-slim
+
+# Set the working directory
+WORKDIR /app
+
+# Copy the current directory contents into the container at /app
+COPY . /app
+
+# Install the dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Expose port 7860 (the port the app serves on)
+EXPOSE 7860
+# CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
-title: Phronesis
-emoji:
+title: Phronesis
+emoji: 🌖
 colorFrom: green
-colorTo:
-sdk:
+colorTo: gray
+sdk: gradio
+sdk_version: 5.4.0
+app_file: app.py
 pinned: false
-
-short_description: ML endpoints
+short_description: 'REPORT GEN AND CLASSIFICATION MODEL '
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,129 @@
+# app.py
+import os
+import io
+import uvicorn
+
+import torch
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import JSONResponse
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from torchvision import models, transforms
+from PIL import Image
+import numpy as np
+from huggingface_hub import hf_hub_download
+import pydicom
+import gc
+from model import CombinedModel, ImageToTextProjector
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/")
+async def root(request: Request):
+    return {"message": "Welcome to Phronesis"}
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def dicom_to_png(dicom_data):
+    try:
+        dicom_file = pydicom.dcmread(io.BytesIO(dicom_data))  # wrap the raw bytes in a buffer for pydicom
+        if not hasattr(dicom_file, 'PixelData'):
+            raise HTTPException(status_code=400, detail="No pixel data in DICOM file.")
+
+        pixel_array = dicom_file.pixel_array.astype(np.float32)
+        pixel_array = ((pixel_array - pixel_array.min()) / (pixel_array.ptp())) * 255.0
+        pixel_array = pixel_array.astype(np.uint8)
+
+        img = Image.fromarray(pixel_array).convert("L")
+        return img
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error converting DICOM to PNG: {e}")
+
+# Set up secure model initialization
+HF_TOKEN = os.getenv('HF_TOKEN')
+if not HF_TOKEN:
+    raise ValueError("Missing Hugging Face token in environment variables.")
+
+try:
+    report_generator_tokenizer = AutoTokenizer.from_pretrained(
+        "KYAGABA/combined-multimodal-model",
+        token=HF_TOKEN if HF_TOKEN else None
+    )
+    video_model = models.video.r3d_18(weights="KINETICS400_V1")
+    video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
+    report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
+    projector = ImageToTextProjector(512, report_generator.config.d_model)
+    num_classes = 4
+    combined_model = CombinedModel(video_model, report_generator, num_classes, projector, report_generator_tokenizer)
+    model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
+    state_dict = torch.load(model_file, map_location=device)
+    combined_model.load_state_dict(state_dict)
+    combined_model.to(device).eval()  # keep the model on the same device as the input tensors
+except Exception as e:
+    raise SystemExit(f"Error loading models: {e}")
+
+image_transform = transforms.Compose([
+    transforms.Resize((112, 112)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989])
+])
+
+class_names = ["acute", "normal", "chronic", "lacunar"]
+
+@app.post("/predict/")
+async def predict(files: list[UploadFile]):
+    print(f"Received {len(files)} files")
+    n_frames = 16
+    images = []
+
+    for file in files:
+        ext = file.filename.split('.')[-1].lower()
+        try:
+            if ext in ['dcm', 'ima']:
+                dicom_img = dicom_to_png(await file.read())
+                images.append(dicom_img.convert("RGB"))
+            elif ext in ['png', 'jpeg', 'jpg']:
+                img = Image.open(io.BytesIO(await file.read())).convert("RGB")
+                images.append(img)
+            else:
+                raise HTTPException(status_code=400, detail="Unsupported file type.")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
+
+    if not images:
+        return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
+
+    if len(images) >= n_frames:
+        images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
+    else:
+        images_sampled = images + [images[-1]] * (n_frames - len(images))
+
+    image_tensors = [image_transform(img) for img in images_sampled]
+    images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        class_outputs, generated_report, _ = combined_model(images_tensor)
+        predicted_class = torch.argmax(class_outputs, dim=1).item()
+        predicted_class_name = class_names[predicted_class]
+
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return {
+        "predicted_class": predicted_class_name,
+        "generated_report": generated_report[0] if generated_report else "No report generated."
+    }
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
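For reference, a minimal client sketch for the /predict/ endpoint follows. It is not part of the commit: the base URL and file paths are placeholders, and the repeated multipart field name "files" is what maps the uploads to the `files: list[UploadFile]` parameter above.

# client_example.py — hypothetical client for the /predict/ endpoint (URL and paths are placeholders)
import requests

url = "http://localhost:7860/predict/"  # or the deployed Space URL
paths = ["scan_001.dcm", "scan_002.dcm"]  # .dcm/.ima/.png/.jpeg/.jpg are accepted

# Repeat the "files" field once per upload so FastAPI collects them into the list parameter.
files = [("files", (p, open(p, "rb"), "application/octet-stream")) for p in paths]

response = requests.post(url, files=files)
response.raise_for_status()
result = response.json()
print(result["predicted_class"])
print(result["generated_report"])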
model.py
ADDED
@@ -0,0 +1,56 @@
+# model.py
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForSeq2SeqLM
+
+class ImageToTextProjector(nn.Module):
+    def __init__(self, image_embedding_dim, text_embedding_dim):
+        super(ImageToTextProjector, self).__init__()
+        self.fc = nn.Linear(image_embedding_dim, text_embedding_dim)
+        self.activation = nn.ReLU()
+        self.dropout = nn.Dropout(p=0.5)
+
+    def forward(self, x):
+        x = self.fc(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        return x
+
+class CombinedModel(nn.Module):
+    def __init__(self, video_model, report_generator, num_classes, projector, tokenizer):
+        super(CombinedModel, self).__init__()
+        self.video_model = video_model
+        self.report_generator = report_generator
+        self.classifier = nn.Linear(512, num_classes)
+        self.projector = projector
+        self.dropout = nn.Dropout(p=0.5)
+        self.tokenizer = tokenizer  # Store tokenizer
+
+    def forward(self, images, labels=None):
+        video_embeddings = self.video_model(images)
+        video_embeddings = self.dropout(video_embeddings)
+        class_outputs = self.classifier(video_embeddings)
+        projected_embeddings = self.projector(video_embeddings)
+        encoder_inputs = projected_embeddings.unsqueeze(1)
+
+        if labels is not None:
+            outputs = self.report_generator(
+                inputs_embeds=encoder_inputs,
+                labels=labels
+            )
+            gen_loss = outputs.loss
+            generated_report = None
+        else:
+            generated_report_ids = self.report_generator.generate(
+                inputs_embeds=encoder_inputs,
+                max_length=512,
+                num_beams=4,
+                early_stopping=True
+            )
+            generated_report = self.tokenizer.batch_decode(
+                generated_report_ids, skip_special_tokens=True
+            )
+            gen_loss = None
+
+        return class_outputs, generated_report, gen_loss
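As a rough usage sketch (assumptions: the public BioBART tokenizer stands in for the one app.py downloads from the private repo, and the random clip shape mirrors what the /predict/ endpoint builds), CombinedModel can be wired and run like this:

# Illustrative wiring of CombinedModel; mirrors app.py, with the tokenizer choice as a stated assumption.
import torch
from torchvision import models
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from model import CombinedModel, ImageToTextProjector

video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-v2-base")  # assumption: matches the seq2seq model
projector = ImageToTextProjector(512, report_generator.config.d_model)
model = CombinedModel(video_model, report_generator, num_classes=4, projector=projector, tokenizer=tokenizer).eval()

# One clip of 16 RGB frames at 112x112: (batch, channels, frames, height, width), as built in app.py
clip = torch.randn(1, 3, 16, 112, 112)
with torch.no_grad():
    class_outputs, generated_report, _ = model(clip)
print(class_outputs.shape)   # torch.Size([1, 4]) — logits over the four classes
print(generated_report[0])   # decoded report text (not meaningful without the fine-tuned weights)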
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+# Core dependencies
+torch==2.0.1
+torchvision==0.15.2
+transformers==4.44.2
+gradio==5.0
+numpy==1.26.2
+Pillow==10.0.1
+fastapi
+# Additional dependencies
+huggingface_hub==0.25.1  # Compatible with both transformers and gradio
+torchmetrics==1.5.1
+nltk==3.8.1
+scikit-learn==1.3.0
+tqdm==4.66.1
+sentencepiece==0.1.99
+pydicom==2.4.1