from typing import List, Optional, Union, Literal
from fastapi import FastAPI, Body, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image as PILImage
import torch
import base64
import gc
import io
import os
from starlette.responses import FileResponse

app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")
# Initialize model and processor
MODEL_NAME = "bytedance-research/UI-TARS-7B-DPO"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Prefer float16 on GPU; low_cpu_mem_usage reduces peak RAM while loading
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to(device)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        print("Warning: loading the model in float16 on GPU failed due to insufficient memory; falling back to CPU and float32.")
        device = "cpu"
        gc.collect()
        torch.cuda.empty_cache()  # Release whatever the failed GPU load allocated
        model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True).to(device)
    else:
        raise

processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Pydantic models mirroring the OpenAI chat-completions request schema
class ImageUrl(BaseModel):
    url: str

class Image(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl

class Content(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    role: Literal["user", "system", "assistant"]
    content: Union[str, List[Content]]

class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 128
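# Illustrative request body accepted by the endpoint below (shape only; the
# base64 payload is truncated). Field names follow the models above.
#
# {
#   "messages": [{
#     "role": "user",
#     "content": [
#       {"type": "text", "text": "Describe the screenshot."},
#       {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
#     ]
#   }],
#   "max_tokens": 128
# }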
# Chat-completions endpoint. The route path is an assumption following the
# OpenAI convention, since the original listing omitted the decorator.
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest = Body(...)):
    # Extract the first message's text and (optional) image
    messages = request.messages
    max_tokens = request.max_tokens
    first_message = messages[0]
    image_url = None
    text_content = None
    if isinstance(first_message.content, str):
        text_content = first_message.content
    else:
        for content_item in first_message.content:
            if content_item.type == "image_url":
                image_url = content_item.image_url.url
            elif content_item.type == "text":
                text_content = content_item.text
    # Process the image if provided (only base64 data URLs are supported)
    pil_image = None
    if image_url:
        try:
            if image_url.startswith("data:image"):
                header, encoded = image_url.split(",", 1)
                image_data = base64.b64decode(encoded)
                pil_image = PILImage.open(io.BytesIO(image_data)).convert("RGB")
            else:
                print("Image URL provided, but a base64 data URL was expected.")
        except Exception as e:
            print(f"Error processing image: {e}")
            raise HTTPException(status_code=400, detail=f"Could not decode image: {e}")
    # Generate the response
    try:
        inputs = processor(text=text_content, images=pil_image, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=max_tokens)
        response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}")
    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": response
            }
        }]
    }
# Serve the static frontend; the original listing omitted the route
# decorator, so "/" is assumed here.
@app.get("/")
def index():
    return FileResponse("static/index.html")
# Startup hook; the original listing omitted the decorator, so FastAPI's
# on_event("startup") is assumed here.
@app.on_event("startup")
def startup_event():
    # In Hugging Face Spaces, the application is usually accessible at https://<space_name>.hf.space
    # Here we assume the Space name is 'api-UI-TARS-7B-DPO'
    public_url = "https://api-UI-TARS-7B-DPO.hf.space"
    print(f"Public URL: {public_url}")