import argparse
import os
import shutil
import subprocess
import threading
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from demo import create_rendering_image
from sheap import load_sheap_model
from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats

try:
    import face_alignment
except ImportError:
    raise ImportError(
        "The 'face_alignment' package is required. Please install it via 'pip install face-alignment'."
    )
from sheap.fa_landmark_utils import detect_face_and_crop


class RenderingThread(threading.Thread):
    """Background thread for rendering frames to images."""

    def __init__(
        self,
        render_queue: Queue,
        temp_dir: Path,
        faces: torch.Tensor,
        c2w: torch.Tensor,
        render_size: int,
    ):
        """
        Initialize rendering thread.

        Args:
            render_queue: Queue containing (frame_idx, cropped_frame, verts) tuples
            temp_dir: Directory to save rendered images
            faces: Face indices tensor from FLAME model
            c2w: Camera-to-world transformation matrix
            render_size: Size of each sub-image in the rendered output
        """
        super().__init__(daemon=True)
        self.render_queue = render_queue
        self.temp_dir = temp_dir
        self.faces = faces
        self.c2w = c2w
        self.render_size = render_size
        self.stop_event = threading.Event()
        self.frames_rendered = 0

    def run(self):
        """Process rendering queue until stop signal is received."""
        # Render headless via EGL so no display server is required.
        os.environ["PYOPENGL_PLATFORM"] = "egl"

        while not self.stop_event.is_set():
            try:
                try:
                    item = self.render_queue.get(timeout=0.1)
                except Empty:
                    continue
                # A None sentinel signals that no more work will arrive.
                if item is None:
                    break

                frame_idx, cropped_frame, verts = item

                cropped_pil = Image.fromarray(cropped_frame)
                combined = create_rendering_image(
                    original_image=cropped_pil,
                    verts=verts,
                    faces=self.faces,
                    c2w=self.c2w,
                    output_size=self.render_size,
                )

                output_path = self.temp_dir / f"frame_{frame_idx:06d}.png"
                combined.save(output_path)

                self.frames_rendered += 1
                self.render_queue.task_done()

            except Exception as e:
                if not self.stop_event.is_set():
                    print(f"Error rendering frame: {e}")
                    import traceback

                    traceback.print_exc()

    def stop(self):
        """Signal the thread to stop."""
        self.stop_event.set()
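

# Minimal usage sketch for RenderingThread (illustrative only; `faces`, `c2w`,
# and `verts` come from the FLAME model as in process_video below):
#
#   queue = Queue(maxsize=8)
#   worker = RenderingThread(queue, Path("./frames"), faces, c2w, render_size=512)
#   worker.start()
#   queue.put((0, frame_rgb_uint8, verts))  # (frame_idx, HxWx3 uint8 image, verts)
#   queue.put(None)  # sentinel: no more work
#   worker.join()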


class VideoFrameDataset(IterableDataset):
    """Iterable dataset for streaming video frames with face detection and cropping.

    Uses a background thread for video frame loading while face detection runs in the main thread.
    """

    def __init__(
        self,
        video_path: str,
        fa_model: face_alignment.FaceAlignment,
        smoothing_alpha: float = 0.3,
        frame_buffer_size: int = 32,
    ):
        """
        Initialize video frame dataset.

        Args:
            video_path: Path to video file
            fa_model: FaceAlignment model instance for face detection
            smoothing_alpha: Smoothing factor for the bounding box (1 = no smoothing,
                0 = the box never updates). Lower values mean more smoothing.
            frame_buffer_size: Size of the frame buffer queue for the background thread
        """
        super().__init__()
        self.video_path = video_path
        self.fa_model = fa_model
        self.smoothing_alpha = smoothing_alpha
        self.frame_buffer_size = frame_buffer_size
        self.prev_bbox: Optional[Tuple[int, int, int, int]] = None

        # Probe video metadata up front, then release the capture; frames are
        # re-read sequentially in the background thread during iteration.
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")

        self.fps = cap.get(cv2.CAP_PROP_FPS)
        self.num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()

        print(
            f"Video info: {self.num_frames} frames, {self.fps:.2f} fps, {self.width}x{self.height}"
        )

    def _video_reader_thread(self, frame_queue: Queue, stop_event: threading.Event):
        """Background thread that reads video frames and puts them in a queue.

        Args:
            frame_queue: Queue to put (frame_idx, frame_rgb) tuples
            stop_event: Event to signal thread to stop
        """
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            frame_queue.put(("error", f"Could not open video file: {self.video_path}"))
            return

        frame_idx = 0
        try:
            while not stop_event.is_set():
                ret, frame_bgr = cap.read()
                if not ret:
                    break

                # OpenCV decodes to BGR; convert to RGB for downstream consumers.
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

                frame_queue.put((frame_idx, frame_rgb))
                frame_idx += 1

        finally:
            cap.release()

        # A None sentinel tells the consumer that the video is exhausted.
        frame_queue.put(None)

    def __iter__(self):
        """
        Iterate through video frames sequentially.

        Video frame loading happens in a background thread, while face detection
        and processing happen in the main thread.

        Yields:
            Dictionary containing frame_idx, processed image, and bounding box
        """
        # Reset bbox smoothing state so repeated iterations start fresh.
        self.prev_bbox = None

        frame_queue = Queue(maxsize=self.frame_buffer_size)
        stop_event = threading.Event()
        reader_thread = threading.Thread(
            target=self._video_reader_thread,
            args=(frame_queue, stop_event),
            daemon=True,
        )
        reader_thread.start()

        try:
            while True:
                item = frame_queue.get()

                # None sentinel: the reader thread has exhausted the video.
                if item is None:
                    break

                # The reader reports open failures as an ("error", message) tuple.
                if isinstance(item, tuple) and len(item) == 2 and item[0] == "error":
                    raise RuntimeError(item[1])

                frame_idx, frame_rgb = item

                # (H, W, C) uint8 -> (C, H, W) float in [0, 1]
                image = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0

                bbox = detect_face_and_crop(image, self.fa_model, margin=0.9, shift_up=0.5)

                # Smooth the bbox over time to reduce jitter in the crop.
                bbox = self._smooth_bbox(bbox)
                x0, y0, x1, y1 = bbox

                cropped = image[:, y0:y1, x0:x1]

                # 224x224 for the SHEAP encoder, 512x512 kept for rendering.
                cropped_resized = TF.resize(cropped, [224, 224], antialias=True)
                cropped_for_render = TF.resize(cropped, [512, 512], antialias=True)

                yield {
                    "frame_idx": frame_idx,
                    "image": cropped_resized,
                    "bbox": bbox,
                    "original_frame": frame_rgb,
                    "cropped_frame": cropped_for_render,
                }

        finally:
            # Stop the reader; it is a daemon thread, so a short join suffices.
            stop_event.set()
            reader_thread.join(timeout=1.0)

    def _smooth_bbox(self, bbox: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        """Apply exponential moving average smoothing to bounding box."""
        if self.prev_bbox is None:
            self.prev_bbox = bbox
            return bbox

        x0, y0, x1, y1 = bbox
        prev_x0, prev_y0, prev_x1, prev_y1 = self.prev_bbox

        smoothed = (
            int(self.smoothing_alpha * x0 + (1 - self.smoothing_alpha) * prev_x0),
            int(self.smoothing_alpha * y0 + (1 - self.smoothing_alpha) * prev_y0),
            int(self.smoothing_alpha * x1 + (1 - self.smoothing_alpha) * prev_x1),
            int(self.smoothing_alpha * y1 + (1 - self.smoothing_alpha) * prev_y1),
        )

        self.prev_bbox = smoothed
        return smoothed
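
    # Smoothing note (illustrative arithmetic): each coordinate follows
    # smoothed = alpha * current + (1 - alpha) * previous. With the default
    # smoothing_alpha = 0.3, a previous x0 of 100 and a new detection at 120
    # give 0.3 * 120 + 0.7 * 100 = 106, so the crop drifts toward the new box.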

    def __len__(self) -> int:
        return self.num_frames
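

# Iteration sketch for VideoFrameDataset (illustrative; `fa_model` is built as
# in process_video below):
#
#   dataset = VideoFrameDataset("clip.mp4", fa_model)
#   for sample in dataset:
#       sample["image"]          # (3, 224, 224) float tensor in [0, 1] for SHEAP
#       sample["cropped_frame"]  # (3, 512, 512) crop kept for rendering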


def process_video(
    video_path: str,
    model_type: str = "expressive",
    batch_size: int = 1,
    num_workers: int = 0,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    output_video_path: Optional[str] = None,
    render_size: int = 512,
    num_render_workers: int = 1,
    max_queue_size: int = 128,
) -> List[Dict[str, Any]]:
    """
    Process video frames through SHEAP model and optionally render output video.

    Uses an IterableDataset for efficient sequential video processing without seeking overhead.
    Rendering is done in a background thread, and ffmpeg is used to create the final video.

    Args:
        video_path: Path to video file
        model_type: SHEAP model variant ("paper", "expressive", or "lightweight")
        batch_size: Batch size for processing
        num_workers: Number of workers (0 or 1 only). Will be clamped to max 1.
        device: Device to run model on ("cpu" or "cuda")
        output_video_path: If provided, render and save output video to this path
        render_size: Size of each sub-image in the rendered output
        num_render_workers: Number of background threads for rendering
        max_queue_size: Maximum size of the rendering queue

    Returns:
        List of dictionaries containing frame index, bounding box, and FLAME parameters
    """
    # Warn before clamping; otherwise the check can never trigger.
    if num_workers > 1:
        print("Warning: num_workers > 1 not supported with IterableDataset. Using num_workers=1.")
    num_workers = min(num_workers, 1)
print(f"Loading SHEAP model (type: {model_type})...") |
|
|
sheap_model = load_sheap_model(model_type=model_type) |
|
|
sheap_model.eval() |
|
|
sheap_model = sheap_model.to(device) |
|
|
|
|
|
|
|
|
|
|
|
fa_device = "cpu" if num_workers >= 1 else device |
|
|
print(f"Loading face alignment model on {fa_device}...") |
|
|
fa_model = face_alignment.FaceAlignment( |
|
|
face_alignment.LandmarksType.THREE_D, flip_input=False, device=fa_device |
|
|
) |
|
|
|
|
|
|
|
|
dataset = VideoFrameDataset(video_path, fa_model) |
|
|
dataloader = DataLoader( |
|
|
dataset, |
|
|
batch_size=batch_size, |
|
|
num_workers=num_workers, |
|
|
pin_memory=torch.cuda.is_available(), |
|
|
) |
|
|
|
|
|
print(f"Processing {len(dataset)} frames from {video_path}") |
|
|
|
|
|
|
|
|

    # Rendering state (only used when an output video is requested).
    flame = None
    rendering_threads = []
    render_queue = None
    temp_dir = None
    c2w = None

    if output_video_path:
        print("Loading FLAME model for rendering...")
        flame_dir = Path("FLAME2020/")
        flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt")
        flame = flame.to(device)
        # Identity rotation with the camera translated one unit along +z.
        c2w = torch.tensor(
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32
        )

        temp_dir = Path("./temp_sheap_render/")
        temp_dir.mkdir(parents=True, exist_ok=True)
        print(f"Using temporary directory: {temp_dir}")

        render_queue = Queue(maxsize=max_queue_size)
        for _ in range(num_render_workers):
            thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
            thread.start()
            rendering_threads.append(thread)
        print(f"Started {num_render_workers} background rendering threads")

    results = []
    frame_count = 0

    with torch.no_grad():
        progbar = tqdm(total=len(dataset), desc="Processing frames")
        for batch in dataloader:
            frame_indices = batch["frame_idx"]
            images = batch["image"].to(device)
            bboxes = batch["bbox"]

            # Regress FLAME parameters for the whole batch.
            flame_params_dict = sheap_model(images)

            # Decode vertices only when rendering is requested.
            if output_video_path and flame is not None:
                verts = flame(
                    shape=flame_params_dict["shape_from_facenet"],
                    expression=flame_params_dict["expr"],
                    pose=pose_components_to_rotmats(flame_params_dict),
                    eyelids=flame_params_dict["eyelids"],
                    translation=flame_params_dict["cam_trans"],
                )

            for i in range(len(frame_indices)):
                frame_idx = _extract_scalar(frame_indices[i])
                bbox = tuple(_extract_scalar(b[i]) for b in bboxes)

                result = {
                    "frame_idx": frame_idx,
                    "bbox": bbox,
                    "flame_params": {k: v[i].cpu() for k, v in flame_params_dict.items()},
                }
                results.append(result)

                # Hand the frame off to the background rendering threads.
                if output_video_path:
                    cropped_frame = _tensor_to_numpy_image(batch["cropped_frame"][i])
                    render_queue.put((frame_idx, cropped_frame, verts[i].cpu()))
                    frame_count += 1

            progbar.update(len(frame_indices))
        progbar.close()

    # Drain the render queue and assemble the final video.
    if output_video_path and render_queue is not None:
        _finalize_rendering(
            rendering_threads,
            render_queue,
            num_render_workers,
            temp_dir,
            dataset.fps,
            output_video_path,
        )

    return results
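

# Usage sketch (illustrative paths; mirrors the __main__ block at the bottom):
#
#   results = process_video("clip.mp4", output_video_path="clip_rendered.mp4")
#   results[0]["flame_params"]["expr"]  # FLAME expression coefficients, frame 0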


def _extract_scalar(value: Any) -> int:
    """Extract scalar integer from tensor or return as-is."""
    return value.item() if isinstance(value, torch.Tensor) else value


def _tensor_to_numpy_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert (C, H, W) tensor in [0, 1] to numpy (H, W, C) uint8 in [0, 255]."""
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return (tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
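

# Shape/dtype sketch for _tensor_to_numpy_image (illustrative):
#
#   t = torch.rand(3, 8, 8)      # (C, H, W) float in [0, 1]
#   _tensor_to_numpy_image(t)    # -> (8, 8, 3) ndarray, dtype uint8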


def _finalize_rendering(
    rendering_threads: List[RenderingThread],
    render_queue: Queue,
    num_render_workers: int,
    temp_dir: Path,
    fps: float,
    output_video_path: str,
) -> None:
    """Finish rendering threads and create final video with ffmpeg."""
    print("\nWaiting for rendering threads to complete...")

    # One None sentinel per worker so every thread receives a shutdown signal.
    for _ in range(num_render_workers):
        render_queue.put(None)

    for thread in rendering_threads:
        thread.join()

    total_rendered = sum(thread.frames_rendered for thread in rendering_threads)
    print(f"Rendered {total_rendered} frames")

    print("Creating video with ffmpeg...")
    output_path = Path(output_video_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_cmd = [
        "ffmpeg",
        "-y",  # overwrite output without prompting
        "-framerate",
        str(fps),
        "-i",
        str(temp_dir / "frame_%06d.png"),
        "-c:v",
        "libx264",
        "-pix_fmt",
        "yuv420p",  # broad player compatibility
        "-preset",
        "medium",
        "-crf",
        "23",
        str(output_path),
    ]
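
    # Equivalent shell command (illustrative, for a 30 fps source):
    #   ffmpeg -y -framerate 30 -i temp_sheap_render/frame_%06d.png \
    #       -c:v libx264 -pix_fmt yuv420p -preset medium -crf 23 out.mp4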

    subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
    print(f"Video saved to: {output_video_path}")

    if temp_dir.exists():
        print(f"Removing temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir)
    print("Cleanup complete")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process and render video with SHEAP model.")
    parser.add_argument("in_path", type=str, help="Path to input video file.")
    parser.add_argument(
        "--out_path", type=str, help="Path to save rendered output video.", default=None
    )
    args = parser.parse_args()
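
    # Example invocation (script and file names illustrative):
    #   python process_video.py clip.mp4 --out_path clip_rendered.mp4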

    # Default the output next to the input as "<stem>_rendered.mp4".
    if args.out_path is None:
        args.out_path = str(Path(args.in_path).with_name(f"{Path(args.in_path).stem}_rendered.mp4"))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    results = process_video(
        video_path=args.in_path,
        model_type="expressive",
        device=device,
        output_video_path=args.out_path,
    )