"""
Gradio demo for SHeaP (Self-Supervised Head Geometry Predictor).

Accepts video or image input and renders the SHeaP output overlaid.
"""

import os

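# Select a headless EGL backend for PyOpenGL before any rendering code is imported,
# so offscreen rendering works on servers without a display.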
if "PYOPENGL_PLATFORM" not in os.environ: |
|
|
|
|
|
os.environ["PYOPENGL_PLATFORM"] = "egl" |
|
|
|
|
|
import shutil
import subprocess
import tempfile
from pathlib import Path
from queue import Queue
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.hub
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader

import face_alignment

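# `spaces` is only available when running on Hugging Face Spaces; locally we fall back
# to a stub whose `GPU` decorator is a no-op, so the same code runs in both settings.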
try:
    import spaces

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class spaces:
        @staticmethod
        def GPU(func):
            return func

from demo import create_rendering_image
from sheap import load_sheap_model
from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
from video_demo import RenderingThread, VideoFrameDataset, _tensor_to_numpy_image
from sheap.fa_landmark_utils import detect_face_and_crop

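# Models are loaded lazily on the first request (see initialize_models). c2w is a fixed
# camera-to-world pose: identity rotation, translated one unit along +z.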
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sheap_model = None
flame = None
fa_model = None
c2w = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32)


def initialize_models():
    """Initialize all models (called lazily on first use)."""
    global sheap_model, flame, fa_model

    if sheap_model is not None:
        return

    print("Loading SHeaP model...", flush=True)
    sheap_model = load_sheap_model(model_type="expressive").to(device)
    sheap_model.eval()

    print("Loading FLAME model...", flush=True)
    flame_dir = Path("FLAME2020/")
    flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt").to(
        device
    )

    print("Loading face alignment model...", flush=True)
    torch.hub.set_dir(str(Path(__file__).parent / "face_alignment_cache"))
    fa_model = face_alignment.FaceAlignment(
        face_alignment.LandmarksType.TWO_D, device=str(device), flip_input=False
    )

    print("Models loaded successfully!", flush=True)


@spaces.GPU
def process_image(image: np.ndarray) -> Image.Image:
    """
    Process a single image and return the rendered output.

    Args:
        image: Input image as numpy array (RGB)

    Returns:
        PIL Image with three views side-by-side (original, mesh, blended)
    """
    initialize_models()

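    # Convert the HWC uint8 RGB input to a CHW float tensor in [0, 1], then detect the
    # face and crop around it with a generous margin.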
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    x0, y0, x1, y1 = detect_face_and_crop(image_tensor, fa_model, margin=0.9, shift_up=0.5)
    cropped_tensor = image_tensor[:, y0:y1, x0:x1]

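    # A 224x224 resize of the crop feeds the network; a 512x512 copy of the same crop is
    # kept for the rendered output.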
    cropped_resized = TF.resize(cropped_tensor, [224, 224], antialias=True)
    img_tensor = cropped_resized.unsqueeze(0).to(device)
    cropped_for_render = TF.resize(cropped_tensor, [512, 512], antialias=True)

    with torch.no_grad():
        predictions = sheap_model(img_tensor)

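    # Decode the predicted FLAME parameters (shape, expression, pose, eyelids, camera
    # translation) into mesh vertices.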
    verts = flame(
        shape=predictions["shape_from_facenet"],
        expression=predictions["expr"],
        pose=pose_components_to_rotmats(predictions),
        eyelids=predictions["eyelids"],
        translation=predictions["cam_trans"],
    )
    verts = verts.cpu()

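    # Render the mesh from the fixed camera and compose the three-panel output
    # (original, mesh render, blended overlay).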
    cropped_pil = TF.to_pil_image(cropped_for_render)
    combined = create_rendering_image(
        original_image=cropped_pil,
        verts=verts[0],
        faces=flame.faces,
        c2w=c2w,
        output_size=512,
    )

    return combined


@spaces.GPU
def process_video_frames(video_path: str, temp_dir: Path, progress=gr.Progress()):
    """
    Process video frames on the GPU (inference and rendering).

    Returns the video fps and the number of frames processed.
    """
    initialize_models()

    render_size = 512

    dataset = VideoFrameDataset(video_path, fa_model)
    dataloader = DataLoader(dataset, batch_size=8, num_workers=0)
    fps = dataset.fps
    num_frames = len(dataset)

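    # Rendering runs in background worker threads that consume (frame_idx, frame, verts)
    # tuples from a queue, overlapping frame rendering with model inference.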
    render_queue = Queue(maxsize=32)
    num_render_workers = 4
    rendering_threads = []
    for _ in range(num_render_workers):
        thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
        thread.start()
        rendering_threads.append(thread)

    progress(0, desc="Processing video frames...")
    frame_idx = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            cropped_frames = batch["cropped_frame"]

            predictions = sheap_model(images)
            verts = flame(
                shape=predictions["shape_from_facenet"],
                expression=predictions["expr"],
                pose=pose_components_to_rotmats(predictions),
                eyelids=predictions["eyelids"],
                translation=predictions["cam_trans"],
            )
            verts = verts.cpu()
            for i in range(images.shape[0]):
                cropped_frame = _tensor_to_numpy_image(cropped_frames[i])
                render_queue.put((frame_idx, cropped_frame, verts[i]))
                frame_idx += 1
                progress(
                    0.95 * frame_idx / num_frames, desc=f"Processing frame {frame_idx}/{num_frames}"
                )

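    # One sentinel per worker tells the rendering threads to stop; join to ensure every
    # frame has been written to disk before encoding starts.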
    for _ in range(num_render_workers):
        render_queue.put(None)
    for thread in rendering_threads:
        thread.join()

    if frame_idx == 0:
        raise ValueError("No frames were successfully processed!")

    return fps, frame_idx


def process_video(video_path: str, progress=gr.Progress()) -> str:
    """
    Process a video and return the path to the rendered output video.
    """
    temp_dir = Path(tempfile.mkdtemp())
    try:
        fps, num_frames = process_video_frames(video_path, temp_dir, progress)

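        # Stitch the per-frame PNGs written by the rendering threads into an H.264 MP4.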
        progress(0.95, desc="Encoding video...")
        output_path = temp_dir / "output.mp4"
        ffmpeg_cmd = [
            "ffmpeg",
            "-y",
            "-framerate",
            str(fps),
            "-i",
            str(temp_dir / "frame_%06d.png"),
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-crf",
            "18",
            str(output_path),
        ]
        subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
        progress(1.0, desc="Done!")
        return str(output_path)
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise e


def process_input(image: Optional[np.ndarray], video: Optional[str]):
    """
    Process either an image or a video input.

    Args:
        image: Input image (if provided)
        video: Input video path (if provided)

    Returns:
        A (rendered image, rendered video path) tuple; the element that does not apply
        to the given input is None.
    """
    if image is not None:
        return process_image(image), None
    elif video is not None:
        return None, process_video(video)
    else:
        raise ValueError("Please provide either an image or video!")


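# Gradio UI: image/video upload and a Process button on the left, rendered outputs on the right.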
with gr.Blocks(title="SHeaP Demo") as demo:
    gr.Markdown(
        """
        # SHeaP: Self-Supervised Head Geometry Predictor

        Upload an image or video to predict head geometry and render a 3D mesh overlay!

        The output shows three views:
        - **Left**: Original cropped face
        - **Center**: Rendered FLAME mesh
        - **Right**: Mesh overlaid on the original

        [Project Page](https://nlml.github.io/sheap) | [Paper](https://arxiv.org/abs/2504.12292) | [GitHub](https://github.com/nlml/sheap)
        """
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input")
            image_input = gr.Image(label="Upload Image", type="numpy")
            video_input = gr.Video(label="Upload Video")
            process_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            gr.Markdown("### Output")
            image_output = gr.Image(label="Rendered Image", type="pil")
            video_output = gr.Video(label="Rendered Video")

    gr.Markdown(
        """
        ### Tips:
        - For best results, use images/videos with clearly visible faces
        - The model works best with frontal face views
        - Video processing may take a few minutes depending on length
        """
    )

    process_btn.click(
        fn=process_input,
        inputs=[image_input, video_input],
        outputs=[image_output, video_output],
    )

    gr.Examples(
        examples=[
            ["example_images/00000206.jpg", None],
            [None, "example_videos/dafoe.mp4"],
        ],
        inputs=[image_input, video_input],
        outputs=[image_output, video_output],
        fn=process_input,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()