baliddeki commited on
Commit
a0b8974
·
1 Parent(s): 6cd0958

api with gradido

Browse files
Files changed (2) hide show
  1. app.py +187 -18
  2. requirements.txt +5 -0
app.py CHANGED
@@ -11,6 +11,11 @@ from model import CombinedModel, ImageToTextProjector
11
  import pydicom
12
  import os
13
  import gc
 
 
 
 
 
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
@@ -42,26 +47,38 @@ image_transform = transforms.Compose([
42
  ])
43
 
44
  def dicom_to_image(file_bytes):
 
45
  dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
46
  pixel_array = dicom_file.pixel_array.astype(np.float32)
47
  pixel_array = ((pixel_array - pixel_array.min()) / pixel_array.ptp()) * 255.0
48
  pixel_array = pixel_array.astype(np.uint8)
49
  return Image.fromarray(pixel_array).convert("RGB")
50
 
51
- def predict(files):
52
- if not files:
 
53
  return "No images uploaded.", ""
54
 
55
  processed_imgs = []
56
- for file_obj in files:
57
- filename = file_obj.name.lower()
58
- if filename.endswith((".dcm", ".ima")):
59
- file_bytes = file_obj.read()
60
- img = dicom_to_image(file_bytes)
61
- else:
62
- img = Image.open(file_obj).convert("RGB")
63
- processed_imgs.append(img)
 
 
 
 
 
 
 
 
 
64
 
 
65
  n_frames = 16
66
  if len(processed_imgs) >= n_frames:
67
  images_sampled = [
@@ -71,28 +88,180 @@ def predict(files):
71
  else:
72
  images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))
73
 
74
- tensor_imgs = [image_transform(i) for i in images_sampled]
 
75
  input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
76
 
 
77
  with torch.no_grad():
78
  class_logits, report, _ = combined_model(input_tensor)
79
  class_pred = torch.argmax(class_logits, dim=1).item()
80
  class_name = class_names[class_pred]
81
 
 
82
  gc.collect()
83
  if torch.cuda.is_available():
84
  torch.cuda.empty_cache()
85
 
86
  return class_name, report[0] if report else "No report generated."
87
 
88
- # Gradio Blocks (100% reliable approach)
89
- # Replace your Blocks interface with this simpler Interface approach
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  demo = gr.Interface(
91
- fn=predict,
92
- inputs=gr.File(file_count="multiple", file_types=[".dcm", ".jpg", ".jpeg", ".png", ".bmp"]),
93
- outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
 
 
 
 
 
 
 
94
  title="🩺 Phronesis Medical Report Generator",
95
- description="Upload CT scan images to generate a medical report"
 
 
 
 
 
 
 
 
96
  )
97
 
98
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  import pydicom
12
  import os
13
  import gc
14
+ from fastapi import FastAPI, File, UploadFile, HTTPException
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from typing import List
17
+ import base64
18
+ from fastapi.responses import JSONResponse
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
 
47
  ])
48
 
49
  def dicom_to_image(file_bytes):
50
+ """Convert DICOM file bytes to PIL Image"""
51
  dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
52
  pixel_array = dicom_file.pixel_array.astype(np.float32)
53
  pixel_array = ((pixel_array - pixel_array.min()) / pixel_array.ptp()) * 255.0
54
  pixel_array = pixel_array.astype(np.uint8)
55
  return Image.fromarray(pixel_array).convert("RGB")
56
 
57
+ def process_images(file_data_list):
58
+ """Core image processing logic used by both Gradio and FastAPI"""
59
+ if not file_data_list:
60
  return "No images uploaded.", ""
61
 
62
  processed_imgs = []
63
+
64
+ for file_data in file_data_list:
65
+ filename = file_data.get('filename', '').lower()
66
+ file_content = file_data.get('content')
67
+
68
+ try:
69
+ if filename.endswith((".dcm", ".ima")):
70
+ img = dicom_to_image(file_content)
71
+ else:
72
+ img = Image.open(io.BytesIO(file_content)).convert("RGB")
73
+ processed_imgs.append(img)
74
+ except Exception as e:
75
+ print(f"Error processing file {filename}: {e}")
76
+ continue
77
+
78
+ if not processed_imgs:
79
+ return "No valid images processed.", ""
80
 
81
+ # Sample frames for video model
82
  n_frames = 16
83
  if len(processed_imgs) >= n_frames:
84
  images_sampled = [
 
88
  else:
89
  images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))
90
 
91
+ # Transform images to tensors
92
+ tensor_imgs = [image_transform(img) for img in images_sampled]
93
  input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
94
 
95
+ # Model inference
96
  with torch.no_grad():
97
  class_logits, report, _ = combined_model(input_tensor)
98
  class_pred = torch.argmax(class_logits, dim=1).item()
99
  class_name = class_names[class_pred]
100
 
101
+ # Cleanup
102
  gc.collect()
103
  if torch.cuda.is_available():
104
  torch.cuda.empty_cache()
105
 
106
  return class_name, report[0] if report else "No report generated."
107
 
108
+ def predict_gradio(files):
109
+ """Gradio interface wrapper"""
110
+ if not files:
111
+ return "No images uploaded.", ""
112
+
113
+ file_data_list = []
114
+ for file_obj in files:
115
+ try:
116
+ file_content = file_obj.read() if hasattr(file_obj, 'read') else open(file_obj.name, 'rb').read()
117
+ file_data_list.append({
118
+ 'filename': file_obj.name if hasattr(file_obj, 'name') else str(file_obj),
119
+ 'content': file_content
120
+ })
121
+ except Exception as e:
122
+ print(f"Error reading file: {e}")
123
+ continue
124
+
125
+ return process_images(file_data_list)
126
+
127
+ # Create FastAPI app
128
+ app = FastAPI(
129
+ title="Phronesis ML API",
130
+ description="Medical Image Analysis API with Gradio Interface",
131
+ version="1.0.0"
132
+ )
133
+
134
+ # Add CORS middleware
135
+ app.add_middleware(
136
+ CORSMiddleware,
137
+ allow_origins=["*"],
138
+ allow_credentials=True,
139
+ allow_methods=["*"],
140
+ allow_headers=["*"],
141
+ )
142
+
143
+ @app.get("/")
144
+ async def root():
145
+ """Root endpoint"""
146
+ return {
147
+ "message": "Phronesis ML API",
148
+ "status": "running",
149
+ "endpoints": {
150
+ "predict": "/predict",
151
+ "health": "/health",
152
+ "gradio": "/gradio"
153
+ }
154
+ }
155
+
156
+ @app.get("/health")
157
+ async def health_check():
158
+ """Health check endpoint"""
159
+ return {
160
+ "status": "healthy",
161
+ "model_loaded": True,
162
+ "device": str(device)
163
+ }
164
+
165
+ @app.post("/predict")
166
+ async def predict_api(files: List[UploadFile] = File(...)):
167
+ """
168
+ API endpoint for medical image prediction
169
+
170
+ Args:
171
+ files: List of uploaded image files (DICOM, JPG, PNG, etc.)
172
+
173
+ Returns:
174
+ JSON response with predicted class and generated report
175
+ """
176
+ try:
177
+ if not files:
178
+ raise HTTPException(status_code=400, detail="No files uploaded")
179
+
180
+ # Process uploaded files
181
+ file_data_list = []
182
+ for file in files:
183
+ try:
184
+ content = await file.read()
185
+ file_data_list.append({
186
+ 'filename': file.filename or 'unknown',
187
+ 'content': content
188
+ })
189
+ except Exception as e:
190
+ print(f"Error reading uploaded file {file.filename}: {e}")
191
+ continue
192
+
193
+ if not file_data_list:
194
+ raise HTTPException(status_code=400, detail="No valid files processed")
195
+
196
+ # Get predictions
197
+ predicted_class, generated_report = process_images(file_data_list)
198
+
199
+ # Return results
200
+ return JSONResponse(content={
201
+ "status": "success",
202
+ "data": {
203
+ "predicted_class": predicted_class,
204
+ "generated_report": generated_report,
205
+ "processed_files": len(file_data_list)
206
+ }
207
+ })
208
+
209
+ except HTTPException:
210
+ raise
211
+ except Exception as e:
212
+ print(f"Prediction error: {e}")
213
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
214
+
215
+ @app.exception_handler(Exception)
216
+ async def global_exception_handler(request, exc):
217
+ """Global exception handler"""
218
+ return JSONResponse(
219
+ status_code=500,
220
+ content={
221
+ "status": "error",
222
+ "message": "Internal server error",
223
+ "detail": str(exc)
224
+ }
225
+ )
226
+
227
+ # Create Gradio interface
228
  demo = gr.Interface(
229
+ fn=predict_gradio,
230
+ inputs=gr.File(
231
+ file_count="multiple",
232
+ file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png", ".bmp"],
233
+ label="Upload Medical Images"
234
+ ),
235
+ outputs=[
236
+ gr.Textbox(label="Predicted Class"),
237
+ gr.Textbox(label="Generated Report", lines=5)
238
+ ],
239
  title="🩺 Phronesis Medical Report Generator",
240
+ description="""
241
+ Upload CT scan images to generate a medical report and classification.
242
+
243
+ **Supported formats:** DICOM (.dcm, .ima), JPEG, PNG, BMP
244
+
245
+ **API Endpoint:** `/predict` (POST)
246
+ """,
247
+ examples=[],
248
+ allow_flagging="never"
249
  )
250
 
251
+ # Mount Gradio app to FastAPI
252
+ app = gr.mount_gradio_app(app, demo, path="/gradio")
253
+
254
+ # Launch configuration
255
+ if __name__ == "__main__":
256
+ import uvicorn
257
+
258
+ # For local development
259
+ # uvicorn.run(app, host="0.0.0.0", port=7860)
260
+
261
+ # For Hugging Face Spaces
262
+ demo.launch(
263
+ server_name="0.0.0.0",
264
+ server_port=7860,
265
+ share=True,
266
+ show_error=True
267
+ )
requirements.txt CHANGED
@@ -14,3 +14,8 @@ scikit-learn==1.3.0
14
  tqdm==4.66.1
15
  sentencepiece==0.1.99
16
  pydicom==2.4.1
 
 
 
 
 
 
14
  tqdm==4.66.1
15
  sentencepiece==0.1.99
16
  pydicom==2.4.1
17
+
18
+
19
+
20
+ uvicorn[standard]
21
+ python-multipart