diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d7b47cb87c6012f598dbfbbf7c56d76e1cc12508 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,49 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/gui.png filter=lfs diff=lfs merge=lfs -text
+assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_Bridge/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_Bridge/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_Bridge/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_ConcertCrowd/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_ConcertCrowd/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_ConcertCrowd/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_ConcertStage/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_ConcertStage/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_ConcertStage/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_RiverOcean/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_RiverOcean/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_RiverOcean/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_SpiderMan/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_SpiderMan/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_SpiderMan/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_Volcano/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_Volcano/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_Volcano/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_VolcanoTitan/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_VolcanoTitan/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/camcontrol_VolcanoTitan/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_cog_Monkey/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_cog_Monkey/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_svd_Fish/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_svd_Fish/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Birds/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Birds/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Cocktail/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Cocktail/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Gardening/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Gardening/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Hamburger/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Hamburger/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Jumping/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Monkey/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Monkey/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Owl/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Owl/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Rhino/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Rhino/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Surfing/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_Surfing/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_TimeSquares/first_frame.png filter=lfs diff=lfs merge=lfs -text
+examples/cutdrag_wan_TimeSquares/motion_signal.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/GUIs/README.md b/GUIs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e63293cb92c0ac5e0df4ebca35d24b3fd560187
--- /dev/null
+++ b/GUIs/README.md
@@ -0,0 +1,60 @@
+# Cut and Drag GUI
+
+We provide a GUI for creating cut-and-drag examples that serve as the input motion signal for video generation with **Time to Move**!
+Given an input frame, you can cut polygons out of the initial image and drag them around, transform their colors, and add external images that can be dragged into the scene.
+
+## ✨ General Guide
+- Select an initial image.
+- Draw polygons on the image and drag them across one or more segments.
+- Within each segment you can also rotate, scale, and recolor the polygons!
+- Polygons can also be cut from an external image that you add to the scene.
+- Write a text prompt that will be used to generate the video afterwards.
+- You can preview the motion signal in an in-app demo.
+- When you are done, all the inputs needed for **Time-to-Move** are saved automatically to a selected output directory.
+
+## 🧰 Requirements
+Install dependencies:
+```bash
+pip install PySide6 opencv-python numpy imageio imageio-ffmpeg
+```
+
+## 🚀 Run
+Just run the python script:
+```bash
+python cut_and_drag.py
+```
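+
+If you are running from the repository root, note that the script lives in the `GUIs/` folder, so something like this should work:
+```bash
+cd GUIs
+python cut_and_drag.py
+```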
+
+## 🖱️ How to Use
+* Select Image — Click 🖼️ Select Image and choose an image.
+ * Choose Center Crop / Center Pad at the top of the toolbar if needed.
+* Add a polygon that “cuts” out a part of the image by clicking Add Polygon.
+  * Left-click to add points.
+  * When you have finished drawing the polygon, press ✅ Finish Polygon Selection.
+* Drag to move the polygon.
+  * While editing a segment you’ll see corner circles and a top dot that can be used for scaling and rotating; in the generated video the shape is interpolated between its pose at the start of the segment and its pose at the end.
+  * A color transformation (hue shift) can also be applied per segment to change the polygon’s colors.
+  * Click 🎯 End Segment to capture the annotated segment.
+ * The movement trajectory can be constructed from multiple segments: repeat move → 🎯 End Segment → move → 🎯 End Segment…
+* External Image
+ * Another option is to add an external image to the scene.
+  * Click 🖼️➕ Add External Image and pick a new image.
+ * Position/scale/rotate it for its initial pose, then click ✅ Place External Image to lock its starting pose.
+ * Now animate it like before: mark a polygon, move, etc.
+* Prompt
+ * Type any text prompt you want associated with this example; it will be used later for video generation with our method.
+* Preview and Save
+ * Preview using ▶️ Play Demo.
+ * Click 💾 Save, choose an output folder and then enter a subfolder name.
+ * Click 🆕 New to start a new project.
+
+## Output Files
+* first_frame.png — the initial frame for video generation
+* motion_signal.mp4 — the reference warped video
+* mask.mp4 — grayscale mask of the motion
+* prompt.txt — your prompt text
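+
+For reference, here is a minimal sketch (with an illustrative folder name) of how a saved example could be loaded back in Python, using the same `cv2` and `imageio` packages the GUI depends on:
+```python
+import cv2
+import imageio
+
+out_dir = "path/to/output/MyExample"  # the subfolder chosen in the Save dialog
+
+first_frame = cv2.imread(f"{out_dir}/first_frame.png")                   # BGR initial frame
+motion = imageio.mimread(f"{out_dir}/motion_signal.mp4", memtest=False)  # warped reference frames (RGB)
+mask = imageio.mimread(f"{out_dir}/mask.mp4", memtest=False)             # motion-mask frames
+with open(f"{out_dir}/prompt.txt") as f:
+    prompt = f.read().strip()
+
+print(first_frame.shape, len(motion), len(mask), prompt)
+```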
+
+
+## 🧾 License / Credits
+Built with PySide6, OpenCV, and NumPy.
+You own the images and exports you create with this tool.
+Motivated by [Go-With-The-Flow](https://github.com/GoWithTheFlowPaper/gowiththeflowpaper.github.io), we built this easy-to-use tool for creating cut-and-drag motion signals.
\ No newline at end of file
diff --git a/GUIs/cut_and_drag.py b/GUIs/cut_and_drag.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e719bbee6d2425e44ee9c5edfa1ac9aaa7fa9aa
--- /dev/null
+++ b/GUIs/cut_and_drag.py
@@ -0,0 +1,1904 @@
+# Copyright 2025 Noam Rotstein
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys, cv2, numpy as np
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple, Dict
+import shutil, subprocess
+
+
+from PySide6 import QtCore, QtGui, QtWidgets
+from PySide6.QtCore import Qt, QPointF, Signal
+from PySide6.QtGui import (
+ QImage, QPixmap, QPainterPath, QPen, QColor, QPainter, QPolygonF, QIcon
+)
+from PySide6.QtWidgets import (
+ QApplication, QMainWindow, QFileDialog, QGraphicsView, QGraphicsScene,
+ QGraphicsPixmapItem, QGraphicsPathItem, QGraphicsLineItem,
+ QToolBar, QLabel, QSpinBox, QWidget, QMessageBox,
+ QComboBox, QPushButton, QGraphicsEllipseItem, QFrame, QVBoxLayout, QSlider, QHBoxLayout,
+ QPlainTextEdit
+)
+import imageio
+
+# ------------------------------
+# Utility: numpy <-> QPixmap
+# ------------------------------
+
+def np_bgr_to_qpixmap(img_bgr: np.ndarray) -> QPixmap:
+ h, w = img_bgr.shape[:2]
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+ qimg = QImage(img_rgb.data, w, h, img_rgb.strides[0], QImage.Format.Format_RGB888)
+ return QPixmap.fromImage(qimg.copy())
+
+def np_rgba_to_qpixmap(img_rgba: np.ndarray) -> QPixmap:
+ h, w = img_rgba.shape[:2]
+ qimg = QImage(img_rgba.data, w, h, img_rgba.strides[0], QImage.Format.Format_RGBA8888)
+ return QPixmap.fromImage(qimg.copy())
+
+# ------------------------------
+# Image I/O + fit helpers
+# ------------------------------
+
+def load_first_frame(path: str) -> np.ndarray:
+ if not os.path.exists(path):
+ raise FileNotFoundError(path)
+ low = path.lower()
+ if low.endswith((".mp4", ".mov", ".avi", ".mkv")):
+ cap = cv2.VideoCapture(path)
+ ok, frame = cap.read()
+ cap.release()
+ if not ok:
+ raise RuntimeError("Failed to read first frame from video")
+ return frame
+ img = cv2.imread(path, cv2.IMREAD_COLOR)
+ if img is None:
+ raise RuntimeError("Failed to read image")
+ return img
+
+def resize_then_center_crop(img: np.ndarray, target_h: int, target_w: int, interpolation=cv2.INTER_NEAREST) -> np.ndarray:
+ h, w = img.shape[:2]
+ scale = max(target_w / float(w), target_h / float(h))
+ new_w, new_h = int(round(w * scale)), int(round(h * scale))
+ resized = cv2.resize(img, (new_w, new_h), interpolation=interpolation)
+ y0 = (new_h - target_h) // 2
+ x0 = (new_w - target_w) // 2
+ return resized[y0:y0 + target_h, x0:x0 + target_w]
+
+def fit_center_pad(img: np.ndarray, target_h: int, target_w: int, interpolation=cv2.INTER_NEAREST) -> np.ndarray:
+ h, w = img.shape[:2]
+ scale_h = target_h / float(h)
+ new_w_hfirst = int(round(w * scale_h))
+ new_h_hfirst = target_h
+ if new_w_hfirst <= target_w:
+ resized = cv2.resize(img, (new_w_hfirst, new_h_hfirst), interpolation=interpolation)
+ result = np.zeros((target_h, target_w, 3), dtype=np.uint8)
+ x0 = (target_w - new_w_hfirst) // 2
+ result[:, x0:x0 + new_w_hfirst] = resized
+ return result
+ scale_w = target_w / float(w)
+ new_w_wfirst = target_w
+ new_h_wfirst = int(round(h * scale_w))
+ resized = cv2.resize(img, (new_w_wfirst, new_h_wfirst), interpolation=interpolation)
+ result = np.zeros((target_h, target_w, 3), dtype=np.uint8)
+ y0 = (target_h - new_h_wfirst) // 2
+ result[y0:y0 + new_h_wfirst, :] = resized
+ return result
+
+# ------------------------------
+# Hue utilities
+# ------------------------------
+
+def apply_hue_shift_bgr(img_bgr: np.ndarray, hue_deg: float) -> np.ndarray:
+ """Rotate hue by hue_deg (degrees) in HSV space. S and V unchanged."""
+ if abs(hue_deg) < 1e-6:
+ return img_bgr.copy()
+ hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
+ h = hsv[:, :, 0].astype(np.int16)
+        offset = int(round((hue_deg / 360.0) * 180.0))
+ h = (h + offset) % 180
+ hsv[:, :, 0] = h.astype(np.uint8)
+ return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+
+# ------------------------------
+# Compositing / warping (final)
+# ------------------------------
+
+def alpha_over(bg_bgr: np.ndarray, fg_rgba: np.ndarray) -> np.ndarray:
+ a = (fg_rgba[:, :, 3:4].astype(np.float32) / 255.0)
+ if a.max() == 0:
+ return bg_bgr.copy()
+ fg = fg_rgba[:, :, :3].astype(np.float32)
+ bg = bg_bgr.astype(np.float32)
+ out = fg * a + bg * (1.0 - a)
+ return np.clip(out, 0, 255).astype(np.uint8)
+
+def inpaint_background(image_bgr: np.ndarray, mask_bool: np.ndarray) -> np.ndarray:
+ mask = (mask_bool.astype(np.uint8) * 255)
+ return cv2.inpaint(image_bgr, mask, 3, cv2.INPAINT_TELEA)
+
+def animate_polygon(image_bgr, polygon_xy, path_xy, scales, rotations_deg, interp=cv2.INTER_LINEAR, origin_xy=None):
+ """
+ Returns list of RGBA frames and list of transformed polygons per frame.
+ Uses BORDER_REPLICATE so off-canvas doesn't appear black.
+ """
+ h, w = image_bgr.shape[:2]
+ frames_rgba = []
+ polys_per_frame = []
+
+ if origin_xy is None:
+ if len(path_xy) == 0:
+ raise ValueError("animate_polygon: path_xy is empty and origin_xy not provided.")
+ origin = np.asarray(path_xy[0], dtype=np.float32)
+ else:
+ origin = np.asarray(origin_xy, dtype=np.float32)
+
+ for i in range(len(path_xy)):
+ theta = np.deg2rad(rotations_deg[i]).astype(np.float32)
+ s = float(scales[i])
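+        # Build a 2x3 affine that rotates by theta and scales by s, with the
+        # translation (tx, ty) chosen so that `origin` maps exactly onto path_xy[i].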
+ a11 = s * np.cos(theta); a12 = -s * np.sin(theta)
+ a21 = s * np.sin(theta); a22 = s * np.cos(theta)
+ tx = path_xy[i, 0] - (a11 * origin[0] + a12 * origin[1])
+ ty = path_xy[i, 1] - (a21 * origin[0] + a22 * origin[1])
+ M = np.array([[a11, a12, tx], [a21, a22, ty]], dtype=np.float32)
+
+ warped = cv2.warpAffine(image_bgr, M, (w, h), flags=interp,
+ borderMode=cv2.BORDER_REPLICATE)
+
+ poly = np.asarray(polygon_xy, dtype=np.float32)
+ pts1 = np.hstack([poly, np.ones((len(poly), 1), dtype=np.float32)])
+ poly_t = (M @ pts1.T).T
+ polys_per_frame.append(poly_t.astype(np.float32))
+
+ mask = np.zeros((h, w), dtype=np.uint8)
+ cv2.fillPoly(mask, [poly_t.astype(np.int32)], 255)
+
+ rgba = np.zeros((h, w, 4), dtype=np.uint8)
+ rgba[:, :, :3] = warped
+ rgba[:, :, 3] = mask
+ frames_rgba.append(rgba)
+
+ return frames_rgba, polys_per_frame
+
+def composite_frames(background_bgr, list_of_layer_frame_lists):
+ frames = []
+ T = len(list_of_layer_frame_lists[0]) if list_of_layer_frame_lists else 0
+ for t in range(T):
+ frame = background_bgr.copy()
+ for layer in list_of_layer_frame_lists:
+ frame = alpha_over(frame, layer[t])
+ frames.append(frame)
+ return frames
+
+def save_video_mp4(frames_bgr, path, fps=24):
+ """
+ Write MP4 using imageio (FFmpeg backend) with H.264 + yuv420p so it works on macOS/QuickTime.
+ - Converts BGR->RGB (imageio expects RGB)
+ - Enforces even width/height (needed for yuv420p)
+ - Tags BT.709 and faststart for smooth playback
+ """
+ if not frames_bgr:
+ raise ValueError("No frames to save")
+
+ # Validate and normalize frames (to RGB uint8 and consistent size)
+ h, w = frames_bgr[0].shape[:2]
+ out_frames = []
+ for f in frames_bgr:
+ if f is None:
+ raise RuntimeError("Encountered None frame")
+ # Accept gray/BGR/BGRA; convert to BGR then to RGB
+ if f.ndim == 2:
+ f = cv2.cvtColor(f, cv2.COLOR_GRAY2BGR)
+ elif f.shape[2] == 4:
+ f = cv2.cvtColor(f, cv2.COLOR_BGRA2BGR)
+ elif f.shape[2] != 3:
+ raise RuntimeError("Frames must be gray, BGR, or BGRA")
+ if f.shape[:2] != (h, w):
+ raise RuntimeError("Frame size mismatch during save.")
+ if f.dtype != np.uint8:
+ f = np.clip(f, 0, 255).astype(np.uint8)
+ # BGR -> RGB for imageio/ffmpeg
+ out_frames.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
+
+ # Enforce even dims for yuv420p
+ hh = h - (h % 2)
+ ww = w - (w % 2)
+ if (hh != h) or (ww != w):
+ out_frames = [frm[:hh, :ww] for frm in out_frames]
+ h, w = hh, ww
+
+ # Try libx264 first; fall back to MPEG-4 Part 2 if libx264 missing
+ ffmpeg_common = ['-movflags', '+faststart',
+ '-colorspace', 'bt709', '-color_primaries', 'bt709', '-color_trc', 'bt709',
+ '-tag:v', 'avc1'] # helps QuickTime recognize H.264 properly
+ try:
+ writer = imageio.get_writer(
+ path, format='ffmpeg', fps=float(fps),
+ codec='libx264', pixelformat='yuv420p',
+ ffmpeg_params=ffmpeg_common
+ )
+ except Exception:
+ # Fallback: MPEG-4 (still Mac-friendly, a bit larger/softer)
+ writer = imageio.get_writer(
+ path, format='ffmpeg', fps=float(fps),
+ codec='mpeg4', pixelformat='yuv420p',
+ ffmpeg_params=['-movflags', '+faststart']
+ )
+
+ try:
+ for frm in out_frames:
+ writer.append_data(frm)
+ finally:
+ writer.close()
+
+ if not os.path.exists(path) or os.path.getsize(path) == 0:
+ raise RuntimeError("imageio/ffmpeg produced an empty file. Check that FFmpeg is available.")
+ return path
+
+
+
+# ------------------------------
+# Data structures
+# ------------------------------
+
+PALETTE = [
+ QColor(255, 99, 99), # red
+ QColor(99, 155, 255), # blue
+ QColor(120, 220, 120), # green
+ QColor(255, 200, 80), # orange
+ QColor(200, 120, 255), # purple
+ QColor(120, 255, 255) # cyan
+]
+
+@dataclass
+class Keyframe:
+ pos: np.ndarray # (2,)
+ rot_deg: float
+ scale: float
+ hue_deg: float = 0.0
+
+@dataclass
+class Layer:
+ name: str
+ source_bgr: np.ndarray
+ polygon_xy: Optional[np.ndarray] = None
+ origin_local_xy: Optional[np.ndarray] = None # bbox center in item coords
+ is_external: bool = False
+ pixmap_item: Optional[QtWidgets.QGraphicsPixmapItem] = None
+ outline_item: Optional[QGraphicsPathItem] = None
+ handle_items: List[QtWidgets.QGraphicsItem] = field(default_factory=list)
+ keyframes: List[Keyframe] = field(default_factory=list)
+ path_lines: List[QGraphicsLineItem] = field(default_factory=list)
+ preview_line: Optional[QGraphicsLineItem] = None
+ color: QColor = field(default_factory=lambda: QColor(255, 99, 99))
+
+ def has_polygon(self) -> bool:
+ return self.polygon_xy is not None and len(self.polygon_xy) >= 3
+
+# ------------------------------
+# Handles (scale corners + rotate dot)
+# ------------------------------
+
+class HandleBase(QGraphicsEllipseItem):
+ def __init__(self, r: float, color: QColor, parent=None):
+ super().__init__(-r, -r, 2*r, 2*r, parent)
+ self.setBrush(color)
+ pen = QPen(QColor(0, 0, 0), 1)
+ pen.setCosmetic(True) # (optional) keep 1px outline at any zoom
+ self.setPen(pen)
+ self.setFlag(QGraphicsEllipseItem.ItemIsMovable, False)
+ self.setAcceptHoverEvents(True)
+ self.setZValue(2000)
+ self._item: Optional[QGraphicsPixmapItem] = None
+
+ # 👇 The key line: draw in device coords (no scaling with the polygon)
+ self.setFlag(QGraphicsEllipseItem.ItemIgnoresTransformations, True)
+
+ def set_item(self, item: QGraphicsPixmapItem):
+ self._item = item
+
+ def origin_scene(self) -> QPointF:
+ return self._item.mapToScene(self._item.transformOriginPoint())
+
+class ScaleHandle(HandleBase):
+ def mousePressEvent(self, event: QtWidgets.QGraphicsSceneMouseEvent):
+ if not self._item: return super().mousePressEvent(event)
+ self._start_scale = self._item.scale() if self._item.scale() != 0 else 1.0
+ self._origin_scene = self.origin_scene()
+ v0 = event.scenePos() - self._origin_scene
+ self._d0 = max(1e-6, (v0.x()*v0.x() + v0.y()*v0.y())**0.5)
+ event.accept()
+ def mouseMoveEvent(self, event: QtWidgets.QGraphicsSceneMouseEvent):
+ if not self._item: return super().mouseMoveEvent(event)
+ v = event.scenePos() - self._origin_scene
+ d = max(1e-6, (v.x()*v.x() + v.y()*v.y())**0.5)
+ s = float(self._start_scale * (d / self._d0))
+ s = float(np.clip(s, 0.05, 10.0))
+ self._item.setScale(s)
+ event.accept()
+
+class RotateHandle(HandleBase):
+ def mousePressEvent(self, event: QtWidgets.QGraphicsSceneMouseEvent):
+ if not self._item: return super().mousePressEvent(event)
+ self._start_rot = self._item.rotation()
+ self._origin_scene = self.origin_scene()
+ v0 = event.scenePos() - self._origin_scene
+ self._a0 = np.degrees(np.arctan2(v0.y(), v0.x()))
+ event.accept()
+ def mouseMoveEvent(self, event: QtWidgets.QGraphicsSceneMouseEvent):
+ if not self._item: return super().mouseMoveEvent(event)
+ v = event.scenePos() - self._origin_scene
+ a = np.degrees(np.arctan2(v.y(), v.x()))
+ delta = a - self._a0
+ r = self._start_rot + delta
+ if r > 180: r -= 360
+ if r < -180: r += 360
+ self._item.setRotation(r)
+ event.accept()
+
+# ------------------------------
+# Pixmap item that notifies on transform changes
+# ------------------------------
+
+class NotifyingPixmapItem(QGraphicsPixmapItem):
+ def __init__(self, pm: QPixmap, on_change_cb=None):
+ super().__init__(pm)
+ self._on_change_cb = on_change_cb
+ self.setFlag(QGraphicsPixmapItem.ItemSendsGeometryChanges, True)
+ def itemChange(self, change, value):
+ ret = super().itemChange(change, value)
+ if change in (QGraphicsPixmapItem.ItemPositionHasChanged,
+ QGraphicsPixmapItem.ItemRotationHasChanged,
+ QGraphicsPixmapItem.ItemScaleHasChanged):
+ if callable(self._on_change_cb):
+ self._on_change_cb()
+ return ret
+
+# ------------------------------
+# Canvas
+# ------------------------------
+
+class Canvas(QGraphicsView):
+ MODE_IDLE = 0
+ MODE_DRAW_POLY = 1
+
+ polygon_finished = Signal(bool)
+ end_segment_requested = Signal()
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.setRenderHint(QtGui.QPainter.Antialiasing, True)
+ self.setRenderHint(QtGui.QPainter.SmoothPixmapTransform, False) # NN preview
+ self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
+ self.setDragMode(QGraphicsView.NoDrag)
+ self.scene = QGraphicsScene(self)
+ self.setScene(self.scene)
+
+ self.base_bgr = None
+ self.base_preview_bgr = None
+ self.base_item = None
+ self.layers: List[Layer] = []
+ self.current_layer: Optional[Layer] = None
+ self.layer_index = 0
+
+ self.mode = Canvas.MODE_IDLE
+ self.temp_points: List[QPointF] = []
+ self.temp_path_item: Optional[QGraphicsPathItem] = None
+ self.first_click_marker: Optional[QGraphicsEllipseItem] = None
+
+ self.fit_mode_combo = None
+ self.target_w = 720
+ self.target_h = 480
+
+ # hue preview for current segment (degrees)
+ self.current_segment_hue_deg: float = 0.0
+
+ # Demo playback
+ self.play_timer = QtCore.QTimer(self)
+ self.play_timer.timeout.connect(self._on_play_tick)
+ self.play_frames: List[np.ndarray] = []
+ self.play_index = 0
+ self.player_item: Optional[QGraphicsPixmapItem] = None
+
+ self.setMouseTracking(True)
+ self.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
+ self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
+
+
+ # ------------ small helpers ------------
+ def _remove_if_in_scene(self, item):
+ if item is None:
+ return
+ try:
+ sc = item.scene()
+ if sc is None:
+ return # already detached or removed
+ if sc is self.scene:
+ self.scene.removeItem(item)
+ else:
+ # If the item belongs to a different scene, remove it from THAT scene.
+ sc.removeItem(item)
+ except Exception:
+ pass
+
+
+ def _apply_pose_from_origin_scene(self, item, origin_scene_qp: QPointF, rot: float, scale: float):
+ item.setRotation(float(rot))
+ item.setScale(float(scale) if scale != 0 else 1.0)
+ new_origin = item.mapToScene(item.transformOriginPoint())
+ d = origin_scene_qp - new_origin
+ item.setPos(item.pos() + d)
+
+ # ------------ Icons ------------
+ def make_pentagon_icon(self) -> QIcon:
+ pm = QPixmap(22, 22)
+ pm.fill(Qt.GlobalColor.transparent)
+ p = QPainter(pm)
+ p.setRenderHint(QPainter.Antialiasing, True)
+ pen = QPen(QColor(40, 40, 40)); pen.setWidth(2)
+ p.setPen(pen)
+ r = 8; cx, cy = 11, 11
+ pts = []
+ for i in range(5):
+ ang = -90 + i * 72
+ rad = np.radians(ang)
+ pts.append(QPointF(cx + r * np.cos(rad), cy + r * np.sin(rad)))
+ p.drawPolygon(QPolygonF(pts))
+ p.end()
+ return QIcon(pm)
+
+ # -------- Fit helpers --------
+ def _apply_fit(self, img: np.ndarray) -> np.ndarray:
+ mode = 'Center Crop'
+ if self.fit_mode_combo is not None:
+ mode = self.fit_mode_combo.currentText()
+ if mode == 'Center Pad':
+ return fit_center_pad(img, self.target_h, self.target_w, interpolation=cv2.INTER_NEAREST)
+ else:
+ return resize_then_center_crop(img, self.target_h, self.target_w, interpolation=cv2.INTER_NEAREST)
+
+ def _refresh_inpaint_preview(self):
+ if self.base_bgr is None:
+ return
+ H, W = self.base_bgr.shape[:2]
+ total_mask = np.zeros((H, W), dtype=bool)
+ for L in self.layers:
+ if not L.has_polygon() or L.is_external:
+ continue
+ poly0 = L.polygon_xy.astype(np.int32)
+ mask = np.zeros((H, W), dtype=np.uint8)
+ cv2.fillPoly(mask, [poly0], 255)
+ total_mask |= (mask > 0)
+ inpainted = inpaint_background(self.base_bgr, total_mask)
+ self.base_preview_bgr = inpainted.copy()
+ if self.base_item is not None:
+ self.base_item.setPixmap(np_bgr_to_qpixmap(self.base_preview_bgr))
+
+ # -------- Scene expansion helpers --------
+ def _expand_scene_to_item(self, item: QtWidgets.QGraphicsItem, margin: int = 120, center: bool = True):
+ if item is None:
+ return
+ try:
+ local_rect = item.boundingRect()
+ poly = item.mapToScene(local_rect)
+ r = poly.boundingRect()
+ except Exception:
+ r = self.scene.sceneRect()
+ sr = self.scene.sceneRect().united(r.adjusted(-margin, -margin, margin, margin))
+ self.scene.setSceneRect(sr)
+ self.ensureVisible(r.adjusted(-20, -20, 20, 20))
+ if center:
+ try:
+ self.centerOn(item)
+ except Exception:
+ pass
+
+ # -------- Base image --------
+ def set_base_image(self, bgr_original: np.ndarray):
+ self.scene.clear()
+ for L in self.layers:
+ L.handle_items.clear(); L.path_lines.clear(); L.preview_line = None
+ self.layers.clear(); self.current_layer = None
+ base_for_save = resize_then_center_crop(bgr_original, self.target_h, self.target_w, interpolation=cv2.INTER_AREA)
+ self.base_bgr = base_for_save.copy()
+ self.base_preview_bgr = self._apply_fit(bgr_original)
+ pm = np_bgr_to_qpixmap(self.base_preview_bgr)
+ self.base_item = self.scene.addPixmap(pm)
+ self.base_item.setZValue(0)
+ self.base_item.setTransformationMode(Qt.FastTransformation) # NN
+ self.setSceneRect(0, 0, pm.width(), pm.height())
+
+ # -------- External sprite layer (no keyframe yet) --------
+ def add_external_sprite_layer(self, raw_bgr: np.ndarray) -> 'Layer':
+ if self.base_bgr is None:
+ return None
+ H, W = self.base_bgr.shape[:2]
+ h0, w0 = raw_bgr.shape[:2]
+ target_h = int(0.6 * H)
+ scale = target_h / float(h0)
+ ew = int(round(w0 * scale))
+ eh = int(round(h0 * scale))
+ ext_small = cv2.resize(raw_bgr, (ew, eh), interpolation=cv2.INTER_AREA)
+
+ # pack onto same canvas size as base
+ px = (W - ew) // 2
+ py = (H - eh) // 2
+ source_bgr = np.zeros((H, W, 3), dtype=np.uint8)
+ x0 = max(px, 0); y0 = max(py, 0)
+ x1 = min(px + ew, W); y1 = min(py + eh, H)
+ if x0 < x1 and y0 < y1:
+ sx0 = x0 - px; sy0 = y0 - py
+ sx1 = sx0 + (x1 - x0); sy1 = sy0 + (y1 - y0)
+ source_bgr[y0:y1, x0:x1] = ext_small[sy0:sy1, sx0:sx1]
+
+ rect_poly = np.array([[px, py], [px+ew, py], [px+ew, py+eh], [px, py+eh]], dtype=np.float32)
+ cx, cy = px + ew/2.0, py + eh/2.0
+
+ mask_rect = np.zeros((H, W), dtype=np.uint8)
+ cv2.fillPoly(mask_rect, [rect_poly.astype(np.int32)], 255)
+ rgba = np.dstack([cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB), mask_rect])
+ pm = np_rgba_to_qpixmap(rgba)
+
+ color = PALETTE[self.layer_index % len(PALETTE)]; self.layer_index += 1
+ L = Layer(name=f"Layer {len(self.layers)+1} (ext)", source_bgr=source_bgr, is_external=True,
+ polygon_xy=rect_poly.copy(), origin_local_xy=np.array([cx, cy], dtype=np.float32), color=color)
+ self.layers.append(L); self.current_layer = L
+
+ def on_change():
+ if L.keyframes:
+ self._ensure_preview_line(L)
+ self._relayout_handles(L)
+
+ item = NotifyingPixmapItem(pm, on_change_cb=on_change)
+ item.setZValue(10 + len(self.layers))
+ item.setFlag(QGraphicsPixmapItem.ItemIsMovable, True) # place first
+ item.setFlag(QGraphicsPixmapItem.ItemIsSelectable, False)
+ item.setTransformationMode(Qt.FastTransformation)
+ item.setShapeMode(QGraphicsPixmapItem.ShapeMode.MaskShape)
+ item.setTransformOriginPoint(QPointF(cx, cy))
+ self.scene.addItem(item); L.pixmap_item = item
+
+ # start slightly to the left, still visible
+ min_vis = min(max(40, ew // 5), W // 2)
+ outside_x = (min_vis - (px + ew))
+ item.setPos(outside_x, 0)
+
+ # rect outline + handles (for initial placement)
+ qpath = QPainterPath(QPointF(rect_poly[0,0], rect_poly[0,1]))
+ for i in range(1, len(rect_poly)): qpath.lineTo(QPointF(rect_poly[i,0], rect_poly[i,1]))
+ qpath.closeSubpath()
+ outline = QGraphicsPathItem(qpath, parent=item)
+ outline.setPen(QPen(L.color, 2, Qt.DashLine))
+ outline.setZValue(item.zValue() + 1)
+ L.outline_item = outline
+ self._create_handles_for_layer(L)
+
+ self._expand_scene_to_item(item, center=True)
+ return L
+
+ def place_external_initial_keyframe(self, L: 'Layer'):
+ if not (L and L.pixmap_item): return
+ origin_scene = L.pixmap_item.mapToScene(L.pixmap_item.transformOriginPoint())
+ L.keyframes.append(Keyframe(
+ pos=np.array([origin_scene.x(), origin_scene.y()], dtype=np.float32),
+ rot_deg=float(L.pixmap_item.rotation()),
+ scale=float(L.pixmap_item.scale()) if L.pixmap_item.scale()!=0 else 1.0,
+ hue_deg=0.0
+ ))
+ self._ensure_preview_line(L)
+
+ # -------- Polygon authoring --------
+ def new_layer_from_source(self, name: str, source_bgr: np.ndarray, is_external: bool):
+ color = PALETTE[self.layer_index % len(PALETTE)]; self.layer_index += 1
+ layer = Layer(name=name, source_bgr=source_bgr.copy(), is_external=is_external, color=color)
+ self.layers.append(layer); self.current_layer = layer
+ self.start_draw_polygon(preserve_motion=False)
+
+ def start_draw_polygon(self, preserve_motion: bool):
+ L = self.current_layer
+ if L is None: return
+ if preserve_motion:
+ for it in L.handle_items:
+ it.setVisible(False)
+ else:
+ # We are re-drawing on the current (base) layer: clear visuals first.
+ # REMOVE CHILDREN FIRST (outline/handles/preview), THEN parent pixmap.
+ if L.outline_item is not None:
+ self._remove_if_in_scene(L.outline_item)
+ L.outline_item = None
+ for it in L.handle_items:
+ self._remove_if_in_scene(it)
+ L.handle_items = []
+ if L.preview_line is not None:
+ self._remove_if_in_scene(L.preview_line)
+ L.preview_line = None
+ if L.pixmap_item is not None:
+ self._remove_if_in_scene(L.pixmap_item)
+ L.pixmap_item = None
+
+ L.path_lines = []
+ L.keyframes.clear()
+ L.polygon_xy = None
+ self.mode = Canvas.MODE_DRAW_POLY
+ self.temp_points = []
+ if self.temp_path_item is not None:
+ self.scene.removeItem(self.temp_path_item); self.temp_path_item = None
+ if self.first_click_marker is not None:
+ self.scene.removeItem(self.first_click_marker); self.first_click_marker = None
+ # reset hue preview for new segment
+ self.current_segment_hue_deg = 0.0
+
+ def _compute_ext_rect_from_source(self, src_bgr: np.ndarray) -> np.ndarray:
+ ys, xs = np.where(np.any(src_bgr != 0, axis=2))
+ if len(xs) == 0 or len(ys) == 0:
+ return np.array([[0,0],[0,0],[0,0],[0,0]], dtype=np.float32)
+ x0, x1 = int(xs.min()), int(xs.max())
+ y0, y1 = int(ys.min()), int(ys.max())
+ return np.array([[x0,y0],[x1,y0],[x1,y1],[x0,y1]], dtype=np.float32)
+
+ def _make_rgba_from_bgr_and_maskpoly(self, bgr: np.ndarray, poly: np.ndarray) -> np.ndarray:
+ H, W = bgr.shape[:2]
+ mask = np.zeros((H, W), dtype=np.uint8)
+ if poly is not None and poly.size:
+ cv2.fillPoly(mask, [poly.astype(np.int32)], 255)
+ rgba = np.dstack([cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), mask])
+ return rgba
+
+ def _add_static_external_item(self, bgr_inpainted: np.ndarray, rect_poly: np.ndarray,
+ kf0: 'Keyframe', z_under: float, color: QColor) -> QGraphicsPixmapItem:
+ rgba = self._make_rgba_from_bgr_and_maskpoly(bgr_inpainted, rect_poly)
+ pm = np_rgba_to_qpixmap(rgba)
+ item = QGraphicsPixmapItem(pm)
+ item.setZValue(max(1.0, z_under - 0.2))
+ item.setTransformationMode(Qt.FastTransformation)
+ item.setShapeMode(QGraphicsPixmapItem.ShapeMode.MaskShape)
+ cx = (rect_poly[:,0].min() + rect_poly[:,0].max())/2.0
+ cy = (rect_poly[:,1].min() + rect_poly[:,1].max())/2.0
+ item.setTransformOriginPoint(QPointF(cx, cy))
+ self.scene.addItem(item)
+ self._apply_pose_from_origin_scene(item, QPointF(kf0.pos[0], kf0.pos[1]), kf0.rot_deg, kf0.scale)
+ path = QPainterPath(QPointF(rect_poly[0,0], rect_poly[0,1]))
+ for i in range(1, len(rect_poly)): path.lineTo(QPointF(rect_poly[i,0], rect_poly[i,1]))
+ path.closeSubpath()
+ outline = QGraphicsPathItem(path, parent=item)
+ outline.setPen(QPen(color, 1, Qt.DashLine))
+ outline.setZValue(item.zValue() + 0.01)
+ self._expand_scene_to_item(item, center=False)
+ return item
+
+ def _update_current_item_hue_preview(self):
+ """Live hue preview for current moving polygon sprite."""
+ L = self.current_layer
+ if not (L and L.pixmap_item and L.polygon_xy is not None):
+ return
+ # Rebuild pixmap of moving item with current hue (full strength preview)
+ bgr = L.source_bgr
+ if abs(self.current_segment_hue_deg) > 1e-6:
+ bgr = apply_hue_shift_bgr(bgr, self.current_segment_hue_deg)
+ rgba = self._make_rgba_from_bgr_and_maskpoly(bgr, L.polygon_xy)
+ L.pixmap_item.setPixmap(np_rgba_to_qpixmap(rgba))
+
+ def finish_polygon(self, preserve_motion: bool) -> bool:
+ L = self.current_layer
+ if L is None or self.mode != Canvas.MODE_DRAW_POLY: return False
+ if len(self.temp_points) < 3: return False
+
+ pts_scene = [QtCore.QPointF(p) for p in self.temp_points]
+
+ if preserve_motion and L.pixmap_item is not None and L.is_external:
+ # ===== EXTERNAL split: remove old rect item, add static rect-with-hole and new moving polygon =====
+ old_item = L.pixmap_item
+
+ # polygon in the old item's LOCAL coords (source_bgr space)
+ pts_local_qt = [old_item.mapFromScene(p) for p in pts_scene]
+ pts_local = np.array([[p.x(), p.y()] for p in pts_local_qt], dtype=np.float32)
+
+ # origin for moving poly = polygon bbox center
+ x0, y0 = pts_local.min(axis=0)
+ x1, y1 = pts_local.max(axis=0)
+ cx_local, cy_local = (x0 + x1) / 2.0, (y0 + y1) / 2.0
+
+ rect_poly_prev = (L.polygon_xy.copy()
+ if (L.polygon_xy is not None and len(L.polygon_xy) >= 3)
+ else self._compute_ext_rect_from_source(L.source_bgr))
+
+ # cache pose / z
+ old_origin_scene = old_item.mapToScene(old_item.transformOriginPoint())
+ old_rot = old_item.rotation()
+ old_scale = old_item.scale() if old_item.scale() != 0 else 1.0
+ old_z = old_item.zValue()
+
+ # build RGBA for moving polygon & static rect-with-hole
+ H, W = L.source_bgr.shape[:2]
+ rgb_full = cv2.cvtColor(L.source_bgr, cv2.COLOR_BGR2RGB)
+
+ mov_mask = np.zeros((H, W), dtype=np.uint8); cv2.fillPoly(mov_mask, [pts_local.astype(np.int32)], 255)
+ mov_rgba = np.dstack([rgb_full, mov_mask])
+
+ hole_mask = np.zeros((H, W), dtype=np.uint8); cv2.fillPoly(hole_mask, [pts_local.astype(np.int32)], 255)
+ inpainted_ext = inpaint_background(L.source_bgr, hole_mask > 0)
+ rect_mask = np.zeros((H, W), dtype=np.uint8); cv2.fillPoly(rect_mask, [rect_poly_prev.astype(np.int32)], 255)
+ static_rgba = np.dstack([cv2.cvtColor(inpainted_ext, cv2.COLOR_BGR2RGB), rect_mask])
+
+ # remove old outline/handles and the old item itself
+ if L.outline_item is not None:
+ self._remove_if_in_scene(L.outline_item); L.outline_item = None
+ for it in L.handle_items:
+ self._remove_if_in_scene(it)
+ L.handle_items = []
+ self._remove_if_in_scene(old_item)
+
+ # STATIC rect (non-movable, below)
+ kf0 = L.keyframes[0] if L.keyframes else Keyframe(
+ pos=np.array([old_origin_scene.x(), old_origin_scene.y()], dtype=np.float32),
+ rot_deg=old_rot, scale=old_scale, hue_deg=0.0
+ )
+ static_item = QGraphicsPixmapItem(np_rgba_to_qpixmap(static_rgba))
+ static_item.setZValue(max(1.0, old_z - 0.2))
+ static_item.setTransformationMode(Qt.FastTransformation)
+ static_item.setShapeMode(QGraphicsPixmapItem.ShapeMode.MaskShape)
+ rcx = (rect_poly_prev[:,0].min() + rect_poly_prev[:,0].max())/2.0
+ rcy = (rect_poly_prev[:,1].min() + rect_poly_prev[:,1].max())/2.0
+ static_item.setTransformOriginPoint(QPointF(rcx, rcy))
+ self.scene.addItem(static_item)
+ self._apply_pose_from_origin_scene(static_item, QPointF(kf0.pos[0], kf0.pos[1]), kf0.rot_deg, kf0.scale)
+ # dashed outline for static
+ qpath_rect = QPainterPath(QPointF(rect_poly_prev[0,0], rect_poly_prev[0,1]))
+ for i in range(1, len(rect_poly_prev)): qpath_rect.lineTo(QPointF(rect_poly_prev[i,0], rect_poly_prev[i,1]))
+ qpath_rect.closeSubpath()
+ outline_static = QGraphicsPathItem(qpath_rect, parent=static_item)
+ outline_static.setPen(QPen(L.color, 1, Qt.DashLine))
+ outline_static.setZValue(static_item.zValue() + 0.01)
+
+ # NEW MOVING polygon (on top)
+ def on_change():
+ if L.keyframes:
+ self._ensure_preview_line(L)
+ self._relayout_handles(L)
+
+ poly_item = NotifyingPixmapItem(np_rgba_to_qpixmap(mov_rgba), on_change_cb=on_change)
+ poly_item.setZValue(old_z + 0.2)
+ poly_item.setFlag(QGraphicsPixmapItem.ItemIsMovable, True)
+ poly_item.setFlag(QGraphicsPixmapItem.ItemIsSelectable, False)
+ poly_item.setTransformationMode(Qt.FastTransformation)
+ poly_item.setShapeMode(QGraphicsPixmapItem.ShapeMode.MaskShape)
+ poly_item.setTransformOriginPoint(QPointF(cx_local, cy_local))
+ self.scene.addItem(poly_item)
+ self._apply_pose_from_origin_scene(poly_item, old_origin_scene, old_rot, old_scale)
+
+ # outline/handles on moving polygon
+ qpath = QPainterPath(QPointF(pts_local[0,0], pts_local[0,1]))
+ for i in range(1, len(pts_local)): qpath.lineTo(QPointF(pts_local[i,0], pts_local[i,1]))
+ qpath.closeSubpath()
+ outline_move = QGraphicsPathItem(qpath, parent=poly_item)
+ outline_move.setPen(QPen(L.color, 2))
+ outline_move.setZValue(poly_item.zValue() + 1)
+
+ L.polygon_xy = pts_local
+ L.origin_local_xy = np.array([cx_local, cy_local], dtype=np.float32)
+ L.pixmap_item = poly_item
+ L.outline_item = outline_move
+ self._create_handles_for_layer(L)
+ self._ensure_preview_line(L)
+
+ # live hue preview starts neutral for new poly
+ self.current_segment_hue_deg = 0.0
+ self._expand_scene_to_item(poly_item, center=True)
+
+ else:
+ # ===== BASE image polygon path =====
+ pts = np.array([[p.x(), p.y()] for p in pts_scene], dtype=np.float32)
+ x0, y0 = pts.min(axis=0); x1, y1 = pts.max(axis=0)
+ cx, cy = (x0+x1)/2.0, (y0+y1)/2.0
+
+ L.polygon_xy = pts
+ L.origin_local_xy = np.array([cx, cy], dtype=np.float32)
+
+ rgb = cv2.cvtColor(L.source_bgr, cv2.COLOR_BGR2RGB)
+ h, w = rgb.shape[:2]
+ mask = np.zeros((h, w), dtype=np.uint8)
+ cv2.fillPoly(mask, [pts.astype(np.int32)], 255)
+ rgba = np.dstack([rgb, mask])
+ pm = np_rgba_to_qpixmap(rgba)
+
+ def on_change():
+ if L.keyframes:
+ self._ensure_preview_line(L)
+ self._relayout_handles(L)
+ item = NotifyingPixmapItem(pm, on_change_cb=on_change)
+ item.setZValue(10 + len(self.layers))
+ item.setFlag(QGraphicsPixmapItem.ItemIsMovable, True)
+ item.setFlag(QGraphicsPixmapItem.ItemIsSelectable, False)
+ item.setTransformationMode(Qt.FastTransformation)
+ item.setShapeMode(QGraphicsPixmapItem.ShapeMode.MaskShape)
+ item.setTransformOriginPoint(QPointF(cx, cy))
+ self.scene.addItem(item); L.pixmap_item = item
+
+ qpath = QPainterPath(QPointF(pts[0,0], pts[0,1]))
+ for i in range(1, len(pts)): qpath.lineTo(QPointF(pts[i,0], pts[i,1]))
+ qpath.closeSubpath()
+ outline = QGraphicsPathItem(qpath, parent=item)
+ outline.setPen(QPen(L.color, 2))
+ outline.setZValue(item.zValue() + 1)
+ L.outline_item = outline
+ self._create_handles_for_layer(L)
+
+ origin_scene = item.mapToScene(item.transformOriginPoint())
+ L.keyframes.append(Keyframe(pos=np.array([origin_scene.x(), origin_scene.y()], dtype=np.float32),
+ rot_deg=float(item.rotation()),
+ scale=float(item.scale()) if item.scale()!=0 else 1.0,
+ hue_deg=0.0))
+
+ if not (L.is_external):
+ self._refresh_inpaint_preview()
+
+ if self.temp_path_item is not None: self._remove_if_in_scene(self.temp_path_item); self.temp_path_item = None
+ if self.first_click_marker is not None: self._remove_if_in_scene(self.first_click_marker); self.first_click_marker = None
+ self.temp_points = []
+ self.mode = Canvas.MODE_IDLE
+ # ensure the current hue preview is applied (neutral at first)
+ self._update_current_item_hue_preview()
+ return True
+
+ # -------- UI helpers --------
+ def _create_handles_for_layer(self, L: Layer):
+ if L.polygon_xy is None or L.pixmap_item is None:
+ return
+ x0, y0 = L.polygon_xy.min(axis=0)
+ x1, y1 = L.polygon_xy.max(axis=0)
+ corners = [QPointF(x0,y0), QPointF(x1,y0), QPointF(x1,y1), QPointF(x0,y1)]
+ top_center = QPointF((x0+x1)/2.0, y0)
+ rot_pos = QPointF(top_center.x(), top_center.y() - 24)
+
+ box_path = QPainterPath(corners[0])
+ for p in corners[1:]:
+ box_path.lineTo(p)
+ box_path.closeSubpath()
+ # bbox (dashed) around the polygon
+ bbox_item = QGraphicsPathItem(box_path, parent=L.pixmap_item)
+ pen = QPen(L.color, 1, Qt.DashLine)
+ pen.setCosmetic(True)
+ bbox_item.setPen(pen)
+ bbox_item.setZValue(L.pixmap_item.zValue() + 0.5)
+ L.handle_items.append(bbox_item)
+
+ for c in corners:
+ h = ScaleHandle(6, L.color, parent=L.pixmap_item)
+ h.setPos(c); h.set_item(L.pixmap_item)
+ L.handle_items.append(h)
+ rot_dot = RotateHandle(6, L.color, parent=L.pixmap_item)
+ rot_dot.setPos(rot_pos); rot_dot.set_item(L.pixmap_item)
+ L.handle_items.append(rot_dot)
+ tether = QGraphicsLineItem(QtCore.QLineF(top_center, rot_pos), L.pixmap_item)
+ pen_tether = QPen(L.color, 1)
+ pen_tether.setCosmetic(True)
+ tether.setPen(pen_tether)
+ tether.setZValue(L.pixmap_item.zValue() + 0.4)
+ L.handle_items.append(tether)
+
+ def _relayout_handles(self, L: Layer):
+ if L.polygon_xy is None or L.pixmap_item is None or not L.handle_items:
+ return
+ x0, y0 = L.polygon_xy.min(axis=0); x1, y1 = L.polygon_xy.max(axis=0)
+ corners = [QPointF(x0,y0), QPointF(x1,y0), QPointF(x1,y1), QPointF(x0,y1)]
+ top_center = QPointF((x0+x1)/2.0, y0)
+ rot_pos = QPointF(top_center.x(), top_center.y() - 24)
+ bbox_item = L.handle_items[0]
+ if isinstance(bbox_item, QGraphicsPathItem):
+ box_path = QPainterPath(corners[0])
+ for p in corners[1:]: box_path.lineTo(p)
+ box_path.closeSubpath(); bbox_item.setPath(box_path)
+ for i in range(4):
+ h = L.handle_items[1+i]
+ if isinstance(h, QGraphicsEllipseItem):
+ h.setPos(corners[i])
+ rot_dot = L.handle_items[5]
+ if isinstance(rot_dot, QGraphicsEllipseItem):
+ rot_dot.setPos(rot_pos)
+ tether = L.handle_items[6]
+ if isinstance(tether, QGraphicsLineItem):
+ tether.setLine(QtCore.QLineF(top_center, rot_pos))
+
+ def _ensure_preview_line(self, L: Layer):
+ if L.pixmap_item is None or not L.keyframes:
+ return
+ origin_scene = L.pixmap_item.mapToScene(L.pixmap_item.transformOriginPoint())
+ p0 = L.keyframes[-1].pos
+ p1 = np.array([origin_scene.x(), origin_scene.y()], dtype=np.float32)
+ if L.preview_line is None:
+ line = QGraphicsLineItem(p0[0], p0[1], p1[0], p1[1])
+ line.setPen(QPen(L.color, 1, Qt.DashLine))
+ line.setZValue(950)
+ self.scene.addItem(line)
+ L.preview_line = line
+ else:
+ L.preview_line.setLine(p0[0], p0[1], p1[0], p1[1])
+
+ def _update_temp_path_item(self, color: QColor):
+ if self.temp_path_item is None:
+ self.temp_path_item = QGraphicsPathItem()
+ pen = QPen(color, 2)
+ self.temp_path_item.setPen(pen)
+ self.temp_path_item.setZValue(1000)
+ self.scene.addItem(self.temp_path_item)
+ if not self.temp_points:
+ self.temp_path_item.setPath(QPainterPath())
+ return
+ path = QPainterPath(self.temp_points[0])
+ for p in self.temp_points[1:]:
+ path.lineTo(p)
+ path.lineTo(self.temp_points[0])
+ self.temp_path_item.setPath(path)
+
+ # -------------- Mouse / Keys for polygon drawing --------------
+ def mousePressEvent(self, event):
+ # Right-click = End Segment (ONLY when not drawing a polygon)
+ if self.mode != Canvas.MODE_DRAW_POLY and event.button() == Qt.RightButton:
+ self.end_segment_requested.emit()
+ event.accept()
+ return
+
+ if self.mode == Canvas.MODE_DRAW_POLY:
+ try:
+ p = event.position()
+ scene_pos = self.mapToScene(int(p.x()), int(p.y()))
+ except AttributeError:
+ scene_pos = self.mapToScene(event.pos())
+
+ if event.button() == Qt.LeftButton:
+ self.temp_points.append(scene_pos)
+ if len(self.temp_points) == 1:
+ if self.first_click_marker is not None:
+ self.scene.removeItem(self.first_click_marker)
+ self.first_click_marker = QGraphicsEllipseItem(-3, -3, 6, 6)
+ self.first_click_marker.setBrush(QColor(0, 220, 0))
+ self.first_click_marker.setPen(QPen(QColor(0, 0, 0), 1))
+ self.first_click_marker.setZValue(1200)
+ self.scene.addItem(self.first_click_marker)
+ self.first_click_marker.setPos(scene_pos)
+ color = self.current_layer.color if self.current_layer else QColor(255, 0, 0)
+ self._update_temp_path_item(color)
+ elif event.button() == Qt.RightButton:
+ # Finish polygon with right-click
+ preserve = (
+ self.current_layer is not None
+ and self.current_layer.pixmap_item is not None
+ and self.current_layer.is_external
+ )
+ ok = self.finish_polygon(preserve_motion=preserve)
+ self.polygon_finished.emit(ok)
+ if not ok:
+ QMessageBox.information(self, "Polygon", "Need at least 3 points.")
+ event.accept()
+ return
+
+ return
+
+ super().mousePressEvent(event)
+
+
+ def mouseDoubleClickEvent(self, event):
+ if self.mode == Canvas.MODE_DRAW_POLY:
+ return
+ super().mouseDoubleClickEvent(event)
+
+ def keyPressEvent(self, event: QtGui.QKeyEvent):
+ if self.mode == Canvas.MODE_DRAW_POLY:
+ if event.key() in (Qt.Key_Return, Qt.Key_Enter):
+ # Polygons are finished with right-click now; ignore Enter.
+ return
+ elif event.key() == Qt.Key_Backspace:
+ if self.temp_points:
+ self.temp_points.pop()
+ color = self.current_layer.color if self.current_layer else QColor(255,0,0)
+ self._update_temp_path_item(color)
+ return
+ elif event.key() == Qt.Key_Escape:
+ if self.temp_path_item is not None: self.scene.removeItem(self.temp_path_item); self.temp_path_item = None
+ if self.first_click_marker is not None: self.scene.removeItem(self.first_click_marker); self.first_click_marker = None
+ self.temp_points = []
+ self.mode = Canvas.MODE_IDLE
+ return
+ super().keyPressEvent(event)
+
+ # keyframes
+ def end_segment_add_keyframe(self):
+ if not (self.current_layer and self.current_layer.pixmap_item and (self.current_layer.polygon_xy is not None) and self.current_layer.keyframes):
+ return False
+ item = self.current_layer.pixmap_item
+ origin_scene = item.mapToScene(item.transformOriginPoint())
+ kf = Keyframe(
+ pos=np.array([origin_scene.x(), origin_scene.y()], dtype=np.float32),
+ rot_deg=float(item.rotation()),
+ scale=float(item.scale()) if item.scale()!=0 else 1.0,
+ hue_deg=float(self.current_segment_hue_deg)
+ )
+ L = self.current_layer
+ if len(L.keyframes) >= 1:
+ p0 = L.keyframes[-1].pos; p1 = kf.pos
+ if L.preview_line is not None:
+ self.scene.removeItem(L.preview_line); L.preview_line = None
+ line = QGraphicsLineItem(p0[0], p0[1], p1[0], p1[1])
+ line.setPen(QPen(L.color, 2)); line.setZValue(900); self.scene.addItem(line)
+ L.path_lines.append(line)
+ L.keyframes.append(kf)
+ self._ensure_preview_line(L)
+ # reset hue for next leg
+ self.current_segment_hue_deg = 0.0
+ # refresh preview back to neutral
+ self._update_current_item_hue_preview()
+ return True
+
+ def has_pending_transform(self) -> bool:
+ L = self.current_layer
+ if not (L and L.pixmap_item and L.keyframes): return False
+ last = L.keyframes[-1]
+ item = L.pixmap_item
+ origin_scene = item.mapToScene(item.transformOriginPoint())
+ pos = np.array([origin_scene.x(), origin_scene.y()], dtype=np.float32)
+ dpos = np.linalg.norm(pos - last.pos)
+ drot = abs(float(item.rotation()) - last.rot_deg)
+ dscale = abs((float(item.scale()) if item.scale()!=0 else 1.0) - last.scale)
+ # hue preview doesn’t count as a “transform” until you end the segment
+ return (dpos > 0.5) or (drot > 0.1) or (dscale > 1e-3)
+
+ def revert_to_last_keyframe(self, L: Optional[Layer] = None):
+ if L is None: L = self.current_layer
+ if not (L and L.pixmap_item and L.keyframes): return
+ last = L.keyframes[-1]
+ item = L.pixmap_item
+ item.setRotation(last.rot_deg); item.setScale(last.scale)
+ origin_scene = item.mapToScene(item.transformOriginPoint())
+ d = QPointF(last.pos[0]-origin_scene.x(), last.pos[1]-origin_scene.y())
+ item.setPos(item.pos() + d)
+ self._ensure_preview_line(L)
+ # restore hue preview to last keyframe hue
+ self.current_segment_hue_deg = last.hue_deg
+ self._update_current_item_hue_preview()
+
+ def _sample_keyframes_constant_speed_with_seg(self, keyframes: List[Keyframe], T: int):
+ """
+ Allocate frames to segments proportional to their Euclidean length so that
+ translation happens at constant speed across the whole path.
+ Returns (pos[T,2], scl[T], rot[T], seg_idx[T], t[T]).
+ """
+ K = len(keyframes)
+ assert K >= 1
+ import math
+ if T <= 0:
+ # degenerate: return just the first pose
+ p0 = keyframes[0].pos.astype(np.float32)
+ return (np.repeat(p0[None, :], 0, axis=0),
+ np.zeros((0,), np.float32),
+ np.zeros((0,), np.float32),
+ np.zeros((0,), np.int32),
+ np.zeros((0,), np.float32))
+
+ if K == 1:
+ p0 = keyframes[0].pos.astype(np.float32)
+ pos = np.repeat(p0[None, :], T, axis=0)
+ scl = np.full((T,), float(keyframes[0].scale), dtype=np.float32)
+ rot = np.full((T,), float(keyframes[0].rot_deg), dtype=np.float32)
+ seg_idx = np.zeros((T,), dtype=np.int32)
+ t = np.zeros((T,), dtype=np.float32)
+ return pos, scl, rot, seg_idx, t
+
+ # Segment lengths (translation only)
+ P = np.array([kf.pos for kf in keyframes], dtype=np.float32) # [K,2]
+ seg_vec = P[1:] - P[:-1] # [K-1,2]
+ lengths = np.linalg.norm(seg_vec, axis=1) # [K-1]
+ total_len = float(lengths.sum())
+
+ def _per_seg_counts_uniform():
+ # fallback: equal frames per segment
+ base = np.zeros((K-1,), dtype=np.int32)
+ if T > 0:
+ # spread as evenly as possible
+ q, r = divmod(T, K-1)
+ base[:] = q
+ base[:r] += 1
+ return base
+
+ if total_len <= 1e-6:
+ counts = _per_seg_counts_uniform()
+ else:
+ # Proportional allocation by length, rounded with largest-remainder
+ raw = (lengths / total_len) * T
+ base = np.floor(raw).astype(np.int32)
+ remainder = T - int(base.sum())
+ if remainder > 0:
+ order = np.argsort(-(raw - base)) # largest fractional parts first
+ base[order[:remainder]] += 1
+ counts = base # may contain zeros for ~zero-length segments
+
+ # Build arrays
+ pos_list, scl_list, rot_list, seg_idx_list, t_list = [], [], [], [], []
+
+ for s in range(K - 1):
+ n = int(counts[s])
+ if n <= 0:
+ continue
+ # Local times in [0,1) to avoid s+1 overflow in hue blending
+ ts = np.linspace(0.0, 1.0, n, endpoint=False, dtype=np.float32)
+
+ p0, p1 = P[s], P[s + 1]
+ s0, s1 = max(1e-6, float(keyframes[s].scale)), max(1e-6, float(keyframes[s + 1].scale))
+ r0, r1 = float(keyframes[s].rot_deg), float(keyframes[s + 1].rot_deg)
+
+ # Interpolate
+ pos_seg = (1 - ts)[:, None] * p0[None, :] + ts[:, None] * p1[None, :]
+ scl_seg = np.exp((1 - ts) * math.log(s0) + ts * math.log(s1))
+ rot_seg = (1 - ts) * r0 + ts * r1
+
+ pos_list.append(pos_seg.astype(np.float32))
+ scl_list.append(scl_seg.astype(np.float32))
+ rot_list.append(rot_seg.astype(np.float32))
+ seg_idx_list.append(np.full((n,), s, dtype=np.int32))
+ t_list.append(ts.astype(np.float32))
+
+ # If counts summed to < T (can happen if T < #segments), pad with final pose of last used seg
+ N = sum(int(c) for c in counts)
+ if N < T:
+ # use final keyframe as hold
+ p_end = P[-1].astype(np.float32)
+ extra = T - N
+ pos_list.append(np.repeat(p_end[None, :], extra, axis=0))
+ scl_list.append(np.full((extra,), float(keyframes[-1].scale), dtype=np.float32))
+ rot_list.append(np.full((extra,), float(keyframes[-1].rot_deg), dtype=np.float32))
+ # Use the last real segment index (K-2), with t=0 (blend start of final seg)
+ seg_idx_list.append(np.full((extra,), max(0, K - 2), dtype=np.int32))
+ t_list.append(np.zeros((extra,), dtype=np.float32))
+
+ pos = np.vstack(pos_list) if pos_list else np.zeros((T, 2), dtype=np.float32)
+ scl = np.concatenate(scl_list) if scl_list else np.zeros((T,), dtype=np.float32)
+ rot = np.concatenate(rot_list) if rot_list else np.zeros((T,), dtype=np.float32)
+ seg_idx = np.concatenate(seg_idx_list) if seg_idx_list else np.zeros((T,), dtype=np.int32)
+ t = np.concatenate(t_list) if t_list else np.zeros((T,), dtype=np.float32)
+
+ # Truncate in case of rounding over-alloc (very rare), or pad if still short
+ if len(pos) > T:
+ pos, scl, rot, seg_idx, t = pos[:T], scl[:T], rot[:T], seg_idx[:T], t[:T]
+ elif len(pos) < T:
+ pad = T - len(pos)
+ pos = np.vstack([pos, np.repeat(pos[-1:,:], pad, axis=0)])
+ scl = np.concatenate([scl, np.repeat(scl[-1:], pad)])
+ rot = np.concatenate([rot, np.repeat(rot[-1:], pad)])
+ seg_idx = np.concatenate([seg_idx, np.repeat(seg_idx[-1:], pad)])
+ t = np.concatenate([t, np.repeat(t[-1:], pad)])
+
+ return pos.astype(np.float32), scl.astype(np.float32), rot.astype(np.float32), seg_idx.astype(np.int32), t.astype(np.float32)
+
+
+ def undo(self) -> bool:
+ if self.mode == Canvas.MODE_DRAW_POLY and self.temp_points:
+ self.temp_points.pop()
+ color = self.current_layer.color if self.current_layer else QColor(255,0,0)
+ self._update_temp_path_item(color)
+ return True
+ if self.has_pending_transform():
+ self.revert_to_last_keyframe()
+ return True
+ if self.current_layer and len(self.current_layer.keyframes) > 1:
+ L = self.current_layer
+ if L.path_lines:
+ line = L.path_lines.pop(); self.scene.removeItem(line)
+ L.keyframes.pop()
+ self.revert_to_last_keyframe(L)
+ return True
+ if self.current_layer:
+ L = self.current_layer
+ if (L.pixmap_item is not None) and (len(L.keyframes) <= 1) and (len(L.path_lines) == 0):
+ if L.preview_line is not None:
+ self.scene.removeItem(L.preview_line); L.preview_line = None
+ if L.outline_item is not None:
+ self.scene.removeItem(L.outline_item); L.outline_item = None
+ for it in L.handle_items:
+ self.scene.removeItem(it)
+ L.handle_items.clear()
+ self.scene.removeItem(L.pixmap_item); L.pixmap_item = None
+ try:
+ idx = self.layers.index(L)
+ self.layers.pop(idx)
+ except ValueError:
+ pass
+ self.current_layer = self.layers[-1] if self.layers else None
+ if (L.is_external is False):
+ self._refresh_inpaint_preview()
+ return True
+ return False
+
+ # -------- Demo playback (with hue crossfade) --------
+ def build_preview_frames(self, T_total: int) -> Optional[List[np.ndarray]]:
+ if self.base_bgr is None:
+ return None
+ H, W = self.base_bgr.shape[:2]
+ total_mask = np.zeros((H, W), dtype=bool)
+ for L in self.layers:
+ if not L.has_polygon() or L.is_external:
+ continue
+ poly0 = L.polygon_xy.astype(np.int32)
+ mask = np.zeros((H, W), dtype=np.uint8); cv2.fillPoly(mask, [poly0], 255)
+ total_mask |= (mask > 0)
+ background = inpaint_background(self.base_bgr, total_mask)
+
+ all_layer_frames = []
+ has_any = False
+
+
+ for L in self.layers:
+ if not L.has_polygon() or len(L.keyframes) < 2:
+ continue
+ has_any = True
+
+            path_xy, scales, rots, seg_idx, t = self._sample_keyframes_constant_speed_with_seg(L.keyframes, T_total)
+
+ origin_xy = L.origin_local_xy if L.origin_local_xy is not None else L.polygon_xy.mean(axis=0)
+
+ # Precompute animations for each keyframe’s hue (crossfade per segment)
+ K = len(L.keyframes)
+ hue_values = [L.keyframes[k].hue_deg for k in range(K)]
+ hue_to_frames: Dict[int, List[np.ndarray]] = {}
+ for k in range(K):
+ bgr_h = apply_hue_shift_bgr(L.source_bgr, hue_values[k])
+ frames_h, _ = animate_polygon(bgr_h, L.polygon_xy, path_xy, scales, rots,
+ interp=cv2.INTER_LINEAR, origin_xy=origin_xy)
+ hue_to_frames[k] = frames_h
+
+ # Mix per frame using seg_idx / t
+ frames_rgba = []
+ for i in range(T_total):
+ s = int(seg_idx[i])
+ w = float(t[i])
+ A = hue_to_frames[s][i].astype(np.float32)
+ B = hue_to_frames[s+1][i].astype(np.float32)
+ mix = (1.0 - w) * A + w * B
+ frames_rgba.append(np.clip(mix, 0, 255).astype(np.uint8))
+ all_layer_frames.append(frames_rgba)
+
+ if not has_any:
+ return None
+ frames_out = composite_frames(background, all_layer_frames)
+ return frames_out
+
+ def play_demo(self, fps: int, T_total: int):
+ frames = self.build_preview_frames(T_total)
+ if not frames:
+ QMessageBox.information(self, "Play Demo", "Nothing to play yet. Add a polygon and keyframes.")
+ return
+ self.play_frames = frames
+ self.play_index = 0
+ if self.player_item is None:
+ self.player_item = QGraphicsPixmapItem()
+ self.player_item.setZValue(5000)
+ self.scene.addItem(self.player_item)
+ self.player_item.setVisible(True)
+ self._on_play_tick()
+ interval_ms = max(1, int(1000 / max(1, fps)))
+ self.play_timer.start(interval_ms)
+
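+ # Timer callback: show the next frame, and stop the timer / hide the overlay when playback ends.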
+ def _on_play_tick(self):
+ if not self.play_frames or self.play_index >= len(self.play_frames):
+ self.play_timer.stop()
+ if self.player_item is not None:
+ self.player_item.setVisible(False)
+ return
+ frame = self.play_frames[self.play_index]
+ self.play_index += 1
+ self.player_item.setPixmap(np_bgr_to_qpixmap(frame))
+
+# ------------------------------
+# Main window / controls
+# ------------------------------
+
+class MainWindow(QMainWindow):
+ def __init__(self):
+ super().__init__()
+ self.setWindowTitle("Time-to-Move: Cut & Drag")
+ self.resize(1180, 840)
+
+ self.canvas = Canvas(self)
+ self.canvas.polygon_finished.connect(self._on_canvas_polygon_finished)
+ self.canvas.end_segment_requested.connect(self.on_end_segment)
+
+ # -------- Instruction banner above canvas (CENTERED) --------
+ self.instruction_label = QLabel()
+ self.instruction_label.setWordWrap(True)
+ self.instruction_label.setAlignment(Qt.AlignHCenter | Qt.AlignVCenter)
+ self.instruction_label.setStyleSheet("""
+ QLabel {
+ background: #f7f7fa;
+ border-bottom: 1px solid #ddd;
+ padding: 10px 12px;
+ font-size: 15px;
+ color: #222;
+ }
+ """)
+ self._set_instruction("Welcome! • Select Image to begin.")
+
+ central = QWidget()
+ v = QVBoxLayout(central)
+ v.setContentsMargins(0,0,0,0); v.setSpacing(0)
+ v.addWidget(self.instruction_label)
+ v.addWidget(self.canvas)
+ self.setCentralWidget(central)
+
+ # state: external placing mode?
+ self.placing_external: bool = False
+ self.placing_layer: Optional[Layer] = None
+
+ # -------- Vertical toolbar on the LEFT --------
+ tb = QToolBar("Tools")
+ self.addToolBar(Qt.LeftToolBarArea, tb)
+ tb.setOrientation(Qt.Vertical)
+
+ def add_btn(text: str, slot, icon: Optional[QIcon] = None):
+ btn = QPushButton(text)
+ if icon: btn.setIcon(icon)
+ btn.setCursor(Qt.PointingHandCursor)
+ btn.setMinimumWidth(180)
+ btn.clicked.connect(slot)
+ tb.addWidget(btn); return btn
+
+ # Fit dropdown (default: Center Crop)
+ self.cmb_fit = QComboBox(); self.cmb_fit.addItems(["Center Crop", "Center Pad"])
+ tb.addWidget(self.cmb_fit)
+ self.canvas.fit_mode_combo = self.cmb_fit
+
+ # Dotted separator
+ line_sep = QFrame(); line_sep.setFrameShape(QFrame.HLine); line_sep.setFrameShadow(QFrame.Plain)
+ line_sep.setStyleSheet("color: #888; border-top: 1px dotted #888; margin: 8px 0;")
+ tb.addWidget(line_sep)
+
+ # Select Image
+ self.btn_select = add_btn("🖼️ Select Image", self.on_select_base)
+
+ # Add Polygon (toggles to Finish)
+ self.pent_icon = self.canvas.make_pentagon_icon()
+ self.btn_add_poly = add_btn("Add Polygon", self.on_add_polygon_toggled, icon=self.pent_icon)
+ self.add_poly_active = False
+
+ # Add External (two-step: file → Place)
+ self.btn_add_external = add_btn("🖼️➕ Add External Image", self.on_add_or_place_external)
+
+ # HUE TRANSFORM (slider + Default) ABOVE End Segment
+ tb.addSeparator()
+ tb.addWidget(QLabel("Hue Transform"))
+ hue_row = QWidget(); row = QHBoxLayout(hue_row); row.setContentsMargins(0,0,0,0)
+ self.sld_hue = QSlider(Qt.Horizontal); self.sld_hue.setRange(-180, 180); self.sld_hue.setValue(0)
+ btn_default = QPushButton("Default"); btn_default.setCursor(Qt.PointingHandCursor); btn_default.setFixedWidth(70)
+ row.addWidget(self.sld_hue, 1); row.addWidget(btn_default, 0)
+ tb.addWidget(hue_row)
+ self.sld_hue.valueChanged.connect(self.on_hue_changed)
+ btn_default.clicked.connect(lambda: self.sld_hue.setValue(0))
+
+ # End Segment and Undo
+ self.btn_end_seg = add_btn("🎯 End Segment", self.on_end_segment)
+ self.btn_undo = add_btn("↩️ Undo", self.on_undo)
+
+ tb.addSeparator()
+ tb.addWidget(QLabel("Total Frames:"))
+ self.spn_total_frames = QSpinBox(); self.spn_total_frames.setRange(1, 2000); self.spn_total_frames.setValue(81)
+ tb.addWidget(self.spn_total_frames)
+ tb.addWidget(QLabel("FPS:"))
+ self.spn_fps = QSpinBox(); self.spn_fps.setRange(1, 120); self.spn_fps.setValue(16)
+ tb.addWidget(self.spn_fps)
+
+ tb.addSeparator()
+ self.btn_play = add_btn("▶️ Play Demo", self.on_play_demo)
+ tb.addWidget(QLabel("Prompt"))
+ self.txt_prompt = QPlainTextEdit()
+ self.txt_prompt.setFixedHeight(80) # ~3–5 lines tall; tweak if you like
+ self.txt_prompt.setMinimumWidth(180) # matches your button width
+ self.txt_prompt.setPlaceholderText("Write your prompt here…")  # placeholder text is supported by QPlainTextEdit in PySide6
+ tb.addWidget(self.txt_prompt)
+ self.btn_save = add_btn("💾 Save", self.on_save)
+ self.btn_new = add_btn("🆕 New", self.on_new)
+ self.btn_exit = add_btn("⏹️ Exit", self.close)
+
+ # Status strip at bottom
+ status = QToolBar("Status")
+ self.addToolBar(Qt.BottomToolBarArea, status)
+ self.status_label = QLabel("Ready")
+ status.addWidget(self.status_label)
+
+ # ---------- Instruction helper ----------
+ def _set_instruction(self, text: str):
+ self.instruction_label.setText(text)
+
+ # ---------- Pending-segment guards ----------
+ def _block_if_pending_segment(self, action_label: str) -> bool:
+ if self.canvas.current_layer and self.canvas.has_pending_transform():
+ QMessageBox.information(
+ self, "Finish Segment",
+ f"Please end the current segment (click '🎯 End Segment') before {action_label}."
+ )
+ self._set_instruction("Finish current segment: drag/scale/rotate as needed, adjust Hue, then click 🎯 End Segment.")
+ return True
+ return False
+
+ # ------------- Actions -------------
+ def on_select_base(self):
+ if self._block_if_pending_segment("changing the base image"):
+ return
+ path, _ = QFileDialog.getOpenFileName(self, "Select image", "", "Images/Videos (*.png *.jpg *.jpeg *.bmp *.mp4 *.mov *.avi *.mkv)")
+ if not path:
+ self._set_instruction("No image selected. Click ‘Select Image’ to begin.")
+ return
+ try:
+ raw = load_first_frame(path)
+ except Exception as e:
+ QMessageBox.critical(self, "Load", f"Failed to load: {e}")
+ return
+ self.canvas.set_base_image(raw)
+ self.add_poly_active = False
+ self.btn_add_poly.setText("Add Polygon")
+ self.placing_external = False; self.placing_layer = None
+ self.btn_add_external.setText("🖼️➕ Add External Image")
+ self.status_label.setText("Base loaded.")
+ self._set_instruction("Step 1: Add a polygon (Add Polygon), or add an external sprite (Add External Image).")
+
+ def on_add_polygon_toggled(self):
+ if self.placing_external:
+ QMessageBox.information(self, "Place External First",
+ "Please place the external image first (click ‘✅ Place External Image’).")
+ self._set_instruction("Place External: drag/scale/rotate to choose starting pose, then click ‘✅ Place External Image’.")
+ return
+
+ if (not self.add_poly_active) and self._block_if_pending_segment("adding a polygon"):
+ return
+
+ if not self.add_poly_active:
+ if self.canvas.base_bgr is None:
+ QMessageBox.information(self, "Add Polygon", "Please select an image first.")
+ self._set_instruction("Click ‘Select Image’ to begin.")
+ return
+
+ # If there's no current layer OR the current layer is BASE -> make a NEW BASE layer (new color).
+ # If the current layer is EXTERNAL -> split/cut that external (preserve motion).
+ if (self.canvas.current_layer is None) or (not self.canvas.current_layer.is_external):
+ # New polygon on the base image => new layer with a fresh color
+ self.canvas.new_layer_from_source(
+ name=f"Layer {len(self.canvas.layers)+1}",
+ source_bgr=self.canvas.base_bgr,
+ is_external=False
+ )
+ else:
+ # Current layer is external: go into "draw polygon to cut external" mode
+ self.canvas.start_draw_polygon(preserve_motion=True)
+
+ self.add_poly_active = True
+ self.btn_add_poly.setText("✅ Finish Polygon Selection")
+ self.status_label.setText("Drawing polygon…")
+ self._set_instruction("Polygon mode: Left-click to add points. Backspace = undo point. Right-click = finish. Esc = cancel.")
+ else:
+ # Finish current polygon selection
+ preserve = (self.canvas.current_layer is not None and
+ self.canvas.current_layer.pixmap_item is not None and
+ self.canvas.current_layer.is_external)
+ ok = self.canvas.finish_polygon(preserve_motion=preserve)
+ if not ok:
+ QMessageBox.information(self, "Polygon", "Need at least 3 points (keep adding).")
+ self._set_instruction("Keep adding polygon points (≥3). Right-click to finish.")
+ return
+ self.add_poly_active = False
+ self.btn_add_poly.setText("Add Polygon")
+ self.status_label.setText("Polygon ready.")
+ self._set_instruction(
+ "Drag to move, use corner circles to scale, top dot to rotate. "
+ "Adjust Hue if you like, then click ‘🎯 End Segment’ or Right Click to record a move."
+ )
+
+ def on_add_or_place_external(self):
+ # If we're already in "placing" mode, finalize the initial keyframe.
+ if self.placing_external and self.placing_layer is not None:
+ try:
+ # Lock the initial pose as keyframe #1
+ self.canvas.place_external_initial_keyframe(self.placing_layer)
+ # Make sure this layer stays selected
+ self.canvas.current_layer = self.placing_layer
+ # Draw the dashed preview line if relevant
+ self.canvas._ensure_preview_line(self.placing_layer)
+ finally:
+ self.placing_external = False
+ self.placing_layer = None
+ self.btn_add_external.setText("🖼️➕ Add External Image")
+ self.status_label.setText("External starting pose locked.")
+ self._set_instruction("Now drag/scale/rotate and click ‘🎯 End Segment’ to record movement.")
+ return
+
+ # Otherwise, begin adding a new external image.
+ if self._block_if_pending_segment("adding an external image"):
+ return
+ if self.canvas.base_bgr is None:
+ QMessageBox.information(self, "External", "Select a base image first.")
+ self._set_instruction("Click ‘Select Image’ to begin.")
+ return
+
+ path, _ = QFileDialog.getOpenFileName(
+ self, "Select external image", "",
+ "Images/Videos (*.png *.jpg *.jpeg *.bmp *.mp4 *.mov *.avi *.mkv)"
+ )
+ if not path:
+ self._set_instruction("External not chosen. You can Add External Image later.")
+ return
+
+ try:
+ raw = load_first_frame(path)
+ except Exception as e:
+ QMessageBox.critical(self, "Load", f"Failed to load external: {e}")
+ return
+
+ L = self.canvas.add_external_sprite_layer(raw) # no keyframe yet
+ if L is None:
+ QMessageBox.critical(self, "External", "Failed to create external layer.")
+ return
+
+ self.placing_external = True
+ self.placing_layer = L
+ self.canvas.current_layer = L # keep selection consistent
+ self.btn_add_external.setText("✅ Place External Image")
+ self.status_label.setText("Place external image.")
+ self._set_instruction("Place External: drag into view, scale with corner circles, rotate with top dot. Then click ‘✅ Place External Image’.")
+
+
+ def _on_canvas_polygon_finished(self, ok: bool):
+ if ok:
+ self.add_poly_active = False
+ self.btn_add_poly.setText("Add Polygon")
+ self.status_label.setText("Polygon ready.")
+ self._set_instruction(
+ "Drag to move, use corner circles to scale, top dot to rotate. "
+ "Adjust Hue if you like, then click ‘🎯 End Segment’ or Right Click to record a move."
+ )
+ else:
+ # Polygon needs ≥3 points; the canvas keeps collecting points, so nothing else to do here
+ pass
+
+ def on_hue_changed(self, val: int):
+ self.canvas.current_segment_hue_deg = float(val)
+ self.canvas._update_current_item_hue_preview()
+
+ def on_end_segment(self):
+ if self.placing_external:
+ QMessageBox.information(self, "Place External First",
+ "Please place the external image first (click ‘✅ Place External Image’).")
+ self._set_instruction("Place External: drag/scale/rotate to choose starting pose, then click ‘✅ Place External Image’.")
+ return
+ ok = self.canvas.end_segment_add_keyframe()
+ if ok:
+ n = len(self.canvas.current_layer.keyframes) if self.canvas.current_layer else 0
+ self.status_label.setText(f"Keyframe #{n} added.")
+ self._set_instruction(
+ "Segment added! Move again for the next leg, adjust Hue if you like, "
+ "then click ‘🎯 End Segment’ or Right Click to record a move."
+ )
+ else:
+ QMessageBox.information(self, "End Segment", "Nothing to record yet. Add/finish a polygon or add/place an external sprite first.")
+ self._set_instruction("Add a polygon (base/external) or place an external image, then drag and click ‘🎯 End Segment’.")
+
+ def on_undo(self):
+ if self.placing_external and self.placing_layer is not None:
+ L = self.placing_layer
+ if L.pixmap_item is not None: self.canvas.scene.removeItem(L.pixmap_item)
+ if L.outline_item is not None: self.canvas.scene.removeItem(L.outline_item)
+ for it in L.handle_items: self.canvas.scene.removeItem(it)
+ try:
+ idx = self.canvas.layers.index(L)
+ self.canvas.layers.pop(idx)
+ except ValueError:
+ pass
+ self.canvas.current_layer = self.canvas.layers[-1] if self.canvas.layers else None
+ self.placing_layer = None
+ self.placing_external = False
+ self.btn_add_external.setText("🖼️➕ Add External Image")
+ self.status_label.setText("External placement canceled.")
+ self._set_instruction("External placement canceled. Add External Image again or continue editing.")
+ return
+
+ if self.canvas.undo():
+ self.status_label.setText("Undo applied.")
+ self._set_instruction("Undone. Continue editing, or click ‘🎯 End Segment’ to record movement.")
+ else:
+ self.status_label.setText("Nothing to undo.")
+ self._set_instruction("Nothing to undo. Drag/scale/rotate and click ‘🎯 End Segment’, or add new shapes.")
+
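+ # Uniform-time sampling: T frames are split evenly across the K-1 keyframe segments;
+ # position/rotation are interpolated linearly and scale is interpolated in log space.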
+ def _sample_keyframes_uniform(self, keyframes: List[Keyframe], T: int):
+ K = len(keyframes); assert K >= 2
+ segs = K - 1
+ u = np.linspace(0.0, float(segs), T, dtype=np.float32)
+ seg_idx = np.minimum(np.floor(u).astype(int), segs - 1)
+ t = u - seg_idx
+ k0 = np.array([[keyframes[i].pos[0], keyframes[i].pos[1], keyframes[i].scale, keyframes[i].rot_deg] for i in seg_idx], dtype=np.float32)
+ k1 = np.array([[keyframes[i+1].pos[0], keyframes[i+1].pos[1], keyframes[i+1].scale, keyframes[i+1].rot_deg] for i in seg_idx], dtype=np.float32)
+ pos0 = k0[:, :2]; pos1 = k1[:, :2]
+ s0 = np.maximum(1e-6, k0[:, 2]); s1 = np.maximum(1e-6, k1[:, 2])
+ r0 = k0[:, 3]; r1 = k1[:, 3]
+ pos = (1 - t)[:, None] * pos0 + t[:, None] * pos1
+ scl = np.exp((1 - t) * np.log(s0) + t * np.log(s1))
+ rot = (1 - t) * r0 + t * r1
+ return pos.astype(np.float32), scl.astype(np.float32), rot.astype(np.float32)
+
+ def on_play_demo(self):
+ if self.canvas.base_bgr is None:
+ QMessageBox.information(self, "Play Demo", "Select an image first.")
+ self._set_instruction("Click ‘Select Image’ to begin.")
+ return
+ has_segments = any((L.polygon_xy is not None and len(L.keyframes) >= 2) for L in self.canvas.layers)
+ if not has_segments:
+ QMessageBox.information(self, "Play Demo", "No motion segments yet. Drag something and click ‘🎯 End Segment’ at least once.")
+ self._set_instruction("Create at least one movement: drag/scale/rotate then click ‘🎯 End Segment’.")
+ return
+ fps = int(self.spn_fps.value())
+ T_total = int(self.spn_total_frames.value())
+ self.canvas.play_demo(fps=fps, T_total=T_total)
+ self._set_instruction("Playing demo… When it ends, you’ll return to the editor. Tweak and play again, or 💾 Save.")
+
+ def on_new(self):
+ if self._block_if_pending_segment("starting a new project"):
+ return
+ self.canvas.scene.clear()
+ self.canvas.layers.clear()
+ self.canvas.current_layer = None
+ self.canvas.base_bgr = None
+ self.canvas.base_preview_bgr = None
+ self.canvas.base_item = None
+ self.add_poly_active = False
+ self.btn_add_poly.setText("Add Polygon")
+ self.placing_external = False; self.placing_layer = None
+ self.btn_add_external.setText("🖼️➕ Add External Image")
+ if hasattr(self, "txt_prompt"):
+ self.txt_prompt.clear()
+ self.status_label.setText("Ready")
+ self._set_instruction("New project. Click ‘Select Image’ to begin.")
+ self.on_select_base()
+
+ def on_save(self):
+ if self._block_if_pending_segment("saving"):
+ return
+ if self.canvas.base_bgr is None or not self.canvas.layers:
+ QMessageBox.information(self, "Save", "Load an image and add at least one polygon/sprite first.")
+ self._set_instruction("Add a polygon (base/external), record segments (🎯 End Segment), then Save.")
+ return
+
+ # If any layer has exactly one keyframe, auto-add the current pose as a second keyframe
+ for L in self.canvas.layers:
+ if L.pixmap_item and L.polygon_xy is not None and len(L.keyframes) == 1:
+ self.canvas.current_layer = L
+ self.canvas.end_segment_add_keyframe()
+
+ # 1) Pick a parent directory
+ base_dir = QtWidgets.QFileDialog.getExistingDirectory(
+ self, "Select output directory", ""
+ )
+ if not base_dir:
+ self._set_instruction("Save canceled. You can keep editing or try ▶️ Play Demo.")
+ return
+
+ # 2) Ask for a subdirectory name
+ subdir_name, ok = QtWidgets.QInputDialog.getText(
+ self, "Subfolder Name", "Create a new subfolder in the selected directory:"
+ )
+ if not ok or not subdir_name.strip():
+ self._set_instruction("Save canceled (no subfolder name).")
+ return
+ subdir_name = subdir_name.strip()
+
+ final_dir = os.path.join(base_dir, subdir_name)
+ if os.path.exists(final_dir):
+ resp = QMessageBox.question(
+ self, "Folder exists",
+ f"'{subdir_name}' already exists in the selected directory.\n"
+ f"Use it and overwrite files?",
+ QMessageBox.Yes | QMessageBox.No, QMessageBox.No
+ )
+ if resp != QMessageBox.Yes:
+ self._set_instruction("Save canceled. Choose another name or directory next time.")
+ return
+ else:
+ try:
+ os.makedirs(final_dir, exist_ok=True)
+ except Exception as e:
+ QMessageBox.critical(self, "Save", f"Failed to create folder:\n{e}")
+ return
+
+ try:
+ prompt_text = self.txt_prompt.toPlainText()
+ except Exception:
+ prompt_text = ""
+ try:
+ with open(os.path.join(final_dir, "prompt.txt"), "w", encoding="utf-8") as f:
+ f.write(prompt_text)
+ except Exception as e:
+ # Non-fatal: continue saving the rest if prompt write fails
+ print(f"[warn] Failed to write prompt.txt: {e}")
+
+ # Output paths
+ first_frame_path = os.path.join(final_dir, "first_frame.png")
+ motion_path = os.path.join(final_dir, "motion_signal.mp4")
+ mask_path = os.path.join(final_dir, "mask.mp4")
+ base_title = subdir_name # for optional numpy save below
+ npy_path = os.path.join(final_dir, f"{base_title}_polygons.npy")
+
+ fps = int(self.spn_fps.value())
+ T_total = int(self.spn_total_frames.value())
+
+ # Build background (inpaint base regions belonging to non-external layers)
+ H, W = self.canvas.base_bgr.shape[:2]
+ total_mask = np.zeros((H, W), dtype=bool)
+ for L in self.canvas.layers:
+ if L.polygon_xy is None:
+ continue
+ if L.is_external:
+ continue
+ poly0 = L.polygon_xy.astype(np.int32)
+ m = np.zeros((H, W), dtype=np.uint8)
+ cv2.fillPoly(m, [poly0], 255)
+ total_mask |= (m > 0)
+ background = inpaint_background(self.canvas.base_bgr, total_mask)
+
+ # Collect animated frames for each layer (with hue crossfade as before)
+ all_layer_frames = []
+ layer_polys = [] # kept for the optional numpy block below
+ for L in self.canvas.layers:
+ if L.polygon_xy is None or len(L.keyframes) < 2:
+ continue
+
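+ # Uniform-time sampler (same interpolation as _sample_keyframes_uniform above) that also
+ # returns the per-frame segment index and local t, which drive the hue crossfade below.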
+ def sample_keyframes_uniform_with_seg(keyframes: List[Keyframe], T: int):
+ K = len(keyframes); assert K >= 1
+ if K == 1:
+ pos = np.repeat(keyframes[0].pos[None, :], T, axis=0).astype(np.float32)
+ scl = np.full((T,), keyframes[0].scale, dtype=np.float32)
+ rot = np.full((T,), keyframes[0].rot_deg, dtype=np.float32)
+ seg_idx = np.zeros((T,), dtype=np.int32)
+ t = np.zeros((T,), dtype=np.float32)
+ return pos, scl, rot, seg_idx, t
+ segs = K - 1
+ u = np.linspace(0.0, float(segs), T, dtype=np.float32)
+ seg_idx = np.minimum(np.floor(u).astype(int), segs - 1)
+ t = u - seg_idx
+ k0 = np.array([[keyframes[i].pos[0], keyframes[i].pos[1], keyframes[i].scale, keyframes[i].rot_deg] for i in seg_idx], dtype=np.float32)
+ k1 = np.array([[keyframes[i+1].pos[0], keyframes[i+1].pos[1], keyframes[i+1].scale, keyframes[i+1].rot_deg] for i in seg_idx], dtype=np.float32)
+ pos0 = k0[:, :2]; pos1 = k1[:, :2]
+ s0 = np.maximum(1e-6, k0[:, 2]); s1 = np.maximum(1e-6, k1[:, 2])
+ r0 = k0[:, 3]; r1 = k1[:, 3]
+ pos = (1 - t)[:, None] * pos0 + t[:, None] * pos1
+ scl = np.exp((1 - t) * np.log(s0) + t * np.log(s1))
+ rot = (1 - t) * r0 + t * r1
+ return pos.astype(np.float32), scl.astype(np.float32), rot.astype(np.float32), seg_idx, t
+
+ path_xy, scales, rots, seg_idx, t = sample_keyframes_uniform_with_seg(L.keyframes, T_total)
+ origin_xy = L.origin_local_xy if L.origin_local_xy is not None else L.polygon_xy.mean(axis=0)
+
+ # Precompute one animation per keyframe hue
+ K = len(L.keyframes)
+ hue_values = [L.keyframes[k].hue_deg for k in range(K)]
+ hue_to_frames: Dict[int, List[np.ndarray]] = {}
+ polys_for_layer = None
+ for k in range(K):
+ bgr_h = apply_hue_shift_bgr(L.source_bgr, hue_values[k])
+ frames_h, polys = animate_polygon(
+ bgr_h, L.polygon_xy, path_xy, scales, rots,
+ interp=cv2.INTER_LINEAR, origin_xy=origin_xy
+ )
+ hue_to_frames[k] = frames_h
+ if polys_for_layer is None: # same polys for all hues
+ polys_for_layer = np.array(polys, dtype=np.float32)
+ if polys_for_layer is not None:
+ layer_polys.append(polys_for_layer)
+
+ # Mix per frame using seg_idx / t
+ frames_rgba = []
+ for i in range(T_total):
+ s = int(seg_idx[i])
+ w = float(t[i])
+ A = hue_to_frames[s][i].astype(np.float32)
+ B = hue_to_frames[s+1][i].astype(np.float32)
+ mix = (1.0 - w) * A + w * B
+ frames_rgba.append(np.clip(mix, 0, 255).astype(np.uint8))
+ all_layer_frames.append(frames_rgba)
+
+ if not all_layer_frames:
+ QMessageBox.information(self, "Save", "No motion segments found. Add keyframes with ‘🎯 End Segment’.")
+ self._set_instruction("Record at least one segment on a layer, then Save.")
+ return
+
+ frames_out = composite_frames(background, all_layer_frames)
+
+ # Build mask frames (union of alpha across layers per frame)
+ mask_frames = []
+ for t in range(T_total):
+ m = np.zeros((H, W), dtype=np.uint16)
+ for Lframes in all_layer_frames:
+ m += Lframes[t][:, :, 3].astype(np.uint16)
+ m = np.clip(m, 0, 255).astype(np.uint8)
+ mask_frames.append(m)
+
+ # --- Actual saving ---
+ try:
+ # first_frame.png (copy of the base image used for saving)
+ cv2.imwrite(first_frame_path, self.canvas.base_bgr)
+
+ # motion_signal.mp4 = composited warped video
+ save_video_mp4(frames_out, motion_path, fps=fps)
+
+ # mask.mp4 = grayscale mask video
+ save_video_mp4([cv2.cvtColor(m, cv2.COLOR_GRAY2BGR) for m in mask_frames], mask_path, fps=fps)
+
+ # Optional: polygons.npy — disabled by default
+ if False:
+ # Pad and save polygons
+ Vmax = 0
+ for P in layer_polys:
+ if P.size:
+ Vmax = max(Vmax, P.shape[1])
+
+ def pad_poly(P: np.ndarray, Vmax_: int) -> np.ndarray:
+ if P.size == 0:
+ return np.zeros((T_total, Vmax_, 2), dtype=np.float32)
+ T_, V, _ = P.shape
+ out = np.zeros((T_, Vmax_, 2), dtype=np.float32)
+ out[:, :V, :] = P
+ if V > 0:
+ out[:, V:, :] = P[:, V-1:V, :]
+ return out
+
+ polys_uniform = np.stack([pad_poly(P, Vmax) for P in layer_polys], axis=0)
+ np.save(npy_path, polys_uniform)
+
+ except Exception as e:
+ QMessageBox.critical(self, "Save", f"Failed to save:\n{e}")
+ return
+
+ QMessageBox.information(self, "Saved", f"Saved to:\n{final_dir}")
+ self._set_instruction("Saved! You can keep editing, play demo again, or start a New project.")
+
+
+# ------------------------------
+# Entry
+# ------------------------------
+
+def main():
+ if sys.version_info < (3, 8):
+ print("[Warning] PySide6 officially supports Python 3.8+. You're on %d.%d." % (sys.version_info.major, sys.version_info.minor))
+ app = QApplication(sys.argv)
+ w = MainWindow()
+ w.show()
+ sys.exit(app.exec())
+
+if __name__ == "__main__":
+ main()
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b703dc224f35a4ebc67a07bc3cde4cd1e0bab860
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..035bf2dd80ff98135e1329fe0129808dcd275005
--- /dev/null
+++ b/README.md
@@ -0,0 +1,169 @@
+# Time-to-Move
+
+**Training-Free Motion-Controlled Video Generation via Dual-Clock Denoising**
+
+
+## Table of Contents
+
+- [Inference](#inference)
+ - [Dual Clock Denoising Guide](#dual-clock-denoising)
+ - [Wan](#wan)
+ - [CogVideoX](#cogvideox)
+ - [Stable Video Diffusion](#stable-video-diffusion)
+- [Generate Your Own Cut-and-Drag Examples](#generate-your-own-cut-and-drag-examples)
+ - [GUI guide](GUIs/README.md)
+- [TODO](#todo)
+- [BibTeX](#bibtex)
+
+
+## Inference
+
+**Time-to-Move (TTM)** is a plug-and-play technique that can be integrated into any image-to-video diffusion model.
+We provide implementations for **Wan 2.2**, **CogVideoX**, and **Stable Video Diffusion (SVD)**.
+As expected, the stronger the base model, the better the resulting videos.
+Adapting TTM to new models and pipelines is straightforward and can typically be done in just a few hours.
+We **recommend using Wan**, which generally produces higher‑quality results and adheres more faithfully to user‑provided motion signals.
+
+
+For each model, you can use the [included examples](./examples/) or create your own as described in
+[Generate Your Own Cut-and-Drag Examples](#generate-your-own-cut-and-drag-examples).
+
+### Dual Clock Denoising
+TTM depends on two hyperparameters that control when denoising starts in different regions, i.e. at different noise depths (illustrated by the sketch below). In practice, we do not pass `tweak` and `tstrong` as raw timesteps; instead we pass `tweak-index` and `tstrong-index`, the iteration (out of the total `num_inference_steps`, 50 for all models) at which each denoising phase begins.
+Constraints: `0 ≤ tweak-index ≤ tstrong-index ≤ num_inference_steps`.
+
+* **tweak-index** — when the denoising process **outside the mask** begins.
+ - Too low: scene deformations, object duplication, or unintended camera motion.
+ - Too high: regions outside the mask look static (e.g., non-moving backgrounds).
+* **tstrong-index** — when the denoising process **within the mask** begins. In our experience, the best value depends on the size and quality of the mask.
+ - Too low: object may drift from the intended path.
+ - Too high: object may look rigid or over-constrained.
+
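+The two indices act as per-region "clocks". The sketch below is only an illustration of that gating, not the shipped pipeline code: the helper name and arguments are ours, and `signal_latents` is assumed to already be noised to the current step's level. Before `tweak-index` everything is pinned to the re-noised motion signal; between the two indices only the region outside the mask is denoised freely; from `tstrong-index` on, the masked region is released as well.
+
+```python
+import numpy as np
+
+def dual_clock_gate(step, tweak_index, tstrong_index, latents, signal_latents, mask):
+    """Return the latents to feed into denoising iteration `step` (illustrative only)."""
+    # `mask` is True inside the user-controlled region and broadcasts against the latents.
+    out = latents
+    if step < tweak_index:     # outside-mask clock has not started: keep it pinned
+        out = np.where(mask, out, signal_latents)
+    if step < tstrong_index:   # inside-mask clock has not started: keep it pinned
+        out = np.where(mask, signal_latents, out)
+    return out
+```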
+
+### Wan
+To set up the environment for running Wan 2.2, follow the installation instructions in the official [Wan 2.2 repository](https://github.com/Wan-Video/Wan2.2). Our implementation builds on the [🤗 Diffusers Wan I2V pipeline](https://github.com/huggingface/diffusers/blob/345864eb852b528fd1f4b6ad087fa06e0470006b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py)
+adapted for TTM using the I2V 14B backbone.
+
+#### Run inference (using the included Wan examples):
+```bash
+python run_wan.py \
+ --input-path "./examples/cutdrag_wan_Monkey" \
+ --output-path "./outputs/wan_monkey.mp4" \
+ --tweak-index 3 \
+ --tstrong-index 7
+```
+
+#### Good starting points:
+* Cut-and-Drag: tweak-index=3, tstrong-index=7
+* Camera control: tweak-index=2, tstrong-index=5
+
+
+
+
+### CogVideoX
+
+To set up the environment for running CogVideoX, follow the installation instructions in the official [CogVideoX repository](https://github.com/zai-org/CogVideo).
+Our implementation builds on the [🤗 Diffusers CogVideoX I2V pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py), which we adapt for Time-to-Move (TTM) using the CogVideoX-I2V 5B backbone.
+
+
+#### Run inference (on the included 49-frame CogVideoX example):
+```bash
+python run_cog.py \
+ --input-path "./examples/cutdrag_cog_Monkey" \
+ --output-path "./outputs/cog_monkey.mp4" \
+ --tweak-index 4 \
+ --tstrong-index 9
+```
+
+
+
+
+
+### Stable Video Diffusion
+
+
+To set up the environment for running SVD, follow the installation instructions in the official [SVD repository](https://github.com/Stability-AI/generative-models).
+Our implementation builds on the [🤗 Diffusers SVD I2V pipeline](https://github.com/huggingface/diffusers/blob/8abc7aeb715c0149ee0a9982b2d608ce97f55215/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py#L147), which we adapt for Time-to-Move (TTM).
+
+#### Run inference (on the included 21-frame SVD example):
+```bash
+python run_svd.py \
+ --input-path "./examples/cutdrag_svd_Fish" \
+ --output-path "./outputs/svd_fish.mp4" \
+ --tweak-index 16 \
+ --tstrong-index 21
+```
+
+
+
+## Generate Your Own Cut-and-Drag Examples
+We provide an easy-to-use GUI for creating cut-and-drag examples that can later be used for video generation in **Time-to-Move**. We recommend reading the [GUI guide](GUIs/README.md) before using it.
+
+
+
+
+
+To get started quickly, create a new environment and run:
+```bash
+pip install PySide6 opencv-python numpy imageio imageio-ffmpeg
+python GUIs/cut_and_drag.py
+```
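+
+The GUI saves each example as a folder with the same layout as the bundled `examples/` directories, so the result can be passed directly to `--input-path`:
+
+```
+<your_example>/
+├── first_frame.png    # base image
+├── motion_signal.mp4  # composited cut-and-drag animation
+├── mask.mp4           # grayscale mask video of the moved regions
+└── prompt.txt         # the prompt entered in the GUI
+```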
+
+
+### TODO 🛠️
+
+- [x] Wan 2.2 run code
+- [x] CogVideoX run code
+- [x] SVD run code
+- [x] Cut-and-Drag examples
+- [x] Camera-control examples
+- [x] Cut-and-Drag GUI
+- [x] Cut-and-Drag GUI guide
+- [ ] Evaluation code
+
+
+## BibTeX
+```
+@misc{singer2025timetomovetrainingfreemotioncontrolled,
+ title={Time-to-Move: Training-Free Motion Controlled Video Generation via Dual-Clock Denoising},
+ author={Assaf Singer and Noam Rotstein and Amir Mann and Ron Kimmel and Or Litany},
+ year={2025},
+ eprint={2511.08633},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2511.08633},
+}
+```
diff --git a/assets/gui.png b/assets/gui.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9df4c72153e94c2c8a5fc1d28ebec0fdf42494d
--- /dev/null
+++ b/assets/gui.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6dfe3d202f383a64ff9c8756868f4f6a6d724b5a7a8ff0e5c93c8c4f546d3e43
+size 223909
diff --git a/assets/logo_arxiv.svg b/assets/logo_arxiv.svg
new file mode 100644
index 0000000000000000000000000000000000000000..2a5f1c5b1521a32694d488f02f273963e7211b07
--- /dev/null
+++ b/assets/logo_arxiv.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/assets/logo_page.svg b/assets/logo_page.svg
new file mode 100644
index 0000000000000000000000000000000000000000..896ce555e8533ba0d45c4a36092e37d7fad6c619
--- /dev/null
+++ b/assets/logo_page.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/assets/logo_paper.svg b/assets/logo_paper.svg
new file mode 100644
index 0000000000000000000000000000000000000000..334c1dd8b17548dd29b286bf2ea6c588edfd6a47
--- /dev/null
+++ b/assets/logo_paper.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/assets/teaser.gif b/assets/teaser.gif
new file mode 100644
index 0000000000000000000000000000000000000000..17f63b761607ad1132ca4d34fc839abf1881f97e
--- /dev/null
+++ b/assets/teaser.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f96d0a7dfe267413e93e45f6c5f5fffb3329a81e3c95934674d5e8b9928e9b5
+size 3136688
diff --git a/docs/HUGGINGFACE.md b/docs/HUGGINGFACE.md
new file mode 100644
index 0000000000000000000000000000000000000000..54ee975db9bfe897f7810295b0bc3b7a5fd08ef2
--- /dev/null
+++ b/docs/HUGGINGFACE.md
@@ -0,0 +1,137 @@
+# Hosting Time-to-Move on Hugging Face
+
+This guide explains how to mirror the Time-to-Move (TTM) codebase on the 🤗 Hub and how to expose an interactive demo through a Space. It assumes you already read the main `README.md`, understand how to run `run_wan.py`, and have access to Wan 2.2 weights through your Hugging Face account.
+
+---
+
+## 1. Prerequisites
+
+- Hugging Face account with access to the Wan 2.2 Image-to-Video model (`Wan-AI/Wan2.2-I2V-A14B-Diffusers` at the time of writing).
+- Local environment with Git, Git LFS, Python 3.10+, and the `huggingface_hub` CLI.
+- GPU-backed hardware both locally (for testing) and on Spaces (A100 or A10 is strongly recommended; CPU-only tiers are too slow for Wan 2.2).
+- Optional: organization namespace on the Hugging Face Hub (recommended if you want to publish under a team/org).
+
+Authenticate once locally (this stores a token in `~/.huggingface`):
+
+```bash
+huggingface-cli login
+git lfs install
+```
+
+---
+
+## 2. Publish the code as a model repository
+
+1. **Create an empty repo on the Hub.** Example:
+
+ ```bash
+ huggingface-cli repo create time-to-move/wan-ttm --type=model --yes
+ git clone https://huggingface.co/time-to-move/wan-ttm
+ cd wan-ttm
+ ```
+
+2. **Copy the TTM sources.** From the project root, copy the files that users need to reproduce inference:
+
+ ```bash
+ rsync -av \
+ --exclude ".git/" \
+ --exclude "outputs/" \
+ /path/to/TTM/ \
+ /path/to/wan-ttm/
+ ```
+
+ Make sure `pipelines/`, `run_wan.py`, `run_cog.py`, `run_svd.py`, `examples/`, and the new `huggingface_space/` folder are included. Track large binary assets:
+
+ ```bash
+ git lfs track "*.mp4" "*.png" "*.gif"
+ git add .gitattributes
+ ```
+
+3. **Add a model card.** Reuse the main `README.md` or create a shorter version describing:
+ - What Time-to-Move does.
+ - How to run `run_wan.py` with the `motion_signal` + `mask`.
+ - Which base model checkpoint the repo expects (Wan 2.2 I2V A14B).
+
+4. **Push to the Hub.**
+
+ ```bash
+ git add .
+ git commit -m "Initial commit of Time-to-Move Wan implementation"
+ git push
+ ```
+
+Users can now do:
+
+```python
+from huggingface_hub import snapshot_download
+snapshot_download("time-to-move/wan-ttm")
+```
+
+---
+
+## 3. Prepare a Hugging Face Space (Gradio)
+
+This repository now contains `huggingface_space/`, a ready-to-use Space template:
+
+```
+huggingface_space/
+├── README.md # Quickstart instructions
+├── app.py # Gradio UI (loads Wan + Time-to-Move logic)
+└── requirements.txt # Runtime dependencies
+```
+
+### 3.1 Create the Space
+
+```bash
+huggingface-cli repo create time-to-move/wan-ttm-demo --type=space --space_sdk=gradio --yes
+git clone https://huggingface.co/spaces/time-to-move/wan-ttm-demo
+cd wan-ttm-demo
+```
+
+Copy everything from `huggingface_space/` into the Space repository (or keep the whole repo and set the Space’s working directory accordingly). Commit and push.
+
+### 3.2 Configure hardware and secrets
+
+- **Hardware:** Select an A100 (preferred) or A10 GPU runtime in the Space settings. Wan 2.2 is too heavy for CPUs.
+- **WAN_MODEL_ID:** If you mirrored Wan 2.2 into your organization, set the environment variable to point to it. Otherwise leave the default (`Wan-AI/Wan2.2-I2V-A14B-Diffusers`).
+- **HF_TOKEN / WAN_ACCESS_TOKEN:** Add a Space secret only if the Wan checkpoint is private. The Gradio app reads from `HF_TOKEN` automatically when calling `from_pretrained`.
+- **PYTORCH_CUDA_ALLOC_CONF:** Recommended value `expandable_segments:True` to reduce CUDA fragmentation.
+
+### 3.3 How the app works
+
+`huggingface_space/app.py` exposes:
+
+- A dropdown of the pre-packaged `examples/cutdrag_wan_*` prompts.
+- Optional custom uploads (`first_frame`, `mask.mp4`, `motion_signal.mp4`) following the README workflow.
+- Sliders for `tweak-index`, `tstrong-index`, guidance scale, seed, etc.
+- Live status messages and a generated MP4 preview using `diffusers.utils.export_to_video`.
+
+The UI lazily loads the `WanImageToVideoTTMPipeline` with tiling/slicing enabled to reduce VRAM usage. All preprocessing matches the logic in `run_wan.py` (the same `compute_hw_from_area` helper is reused).
+
+If you need to customize the experience (e.g., restrict to certain prompts, enforce shorter sequences), edit `huggingface_space/app.py` before pushing.
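+
+For example, a lazy-loading pattern along these lines keeps the heavy pipeline out of memory until the first request. This is only a sketch: the module path is hypothetical and the `from_pretrained` arguments are assumptions, so check the actual `app.py` first.
+
+```python
+import os
+from functools import lru_cache
+
+@lru_cache(maxsize=1)
+def get_pipeline():
+    # Loaded on first use so the Space boots quickly. WAN_MODEL_ID / HF_TOKEN are the
+    # Space settings described above; the pipeline class is the one shipped under pipelines/.
+    from pipelines.wan_ttm_pipeline import WanImageToVideoTTMPipeline  # hypothetical path
+    pipe = WanImageToVideoTTMPipeline.from_pretrained(
+        os.environ.get("WAN_MODEL_ID", "Wan-AI/Wan2.2-I2V-A14B-Diffusers"),
+        token=os.environ.get("HF_TOKEN"),
+    )
+    return pipe.to("cuda")
+```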
+
+---
+
+## 4. Testing checklist
+
+1. **Local dry-run.**
+ ```bash
+ pip install -r huggingface_space/requirements.txt
+ WAN_MODEL_ID=Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+ python huggingface_space/app.py
+ ```
+ Ensure you can generate at least one of the bundled examples.
+
+2. **Space smoke test.**
+ - Open the deployed Space.
+ - Run the default example (`cutdrag_wan_Monkey`) and confirm you receive a video in ~2–3 minutes on A100 hardware.
+ - Optionally upload a small custom mask/video pair and verify that `tweak-index`/`tstrong-index` are honored.
+
+3. **Monitor logs.** Use the Space “Logs” tab to confirm:
+ - The pipeline downloads from the expected `WAN_MODEL_ID`.
+ - VRAM usage stays within the selected hardware tier.
+
+4. **Freeze dependencies.** When satisfied, tag the Space (`v1`, `demo`) so users know which TTM commit it matches.
+
+You now have both a **model repository** (for anyone to clone/run) and a **public Space** for live demos. Feel free to adapt the instructions for the CogVideoX or Stable Video Diffusion pipelines if you plan to expose them as well; start by duplicating the provided Space template and swapping out `run_wan.py` for the relevant runner.
+
diff --git a/examples/camcontrol_Bridge/first_frame.png b/examples/camcontrol_Bridge/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..8d79b6b332749ae76a717758200b53105fa36489
--- /dev/null
+++ b/examples/camcontrol_Bridge/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ee141276b8b202798b6bc727ca46a8f4b6202739464121108dc1304c3e40c10
+size 705982
diff --git a/examples/camcontrol_Bridge/mask.mp4 b/examples/camcontrol_Bridge/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3b417c610f60f57c59260eefd8848dd55f7b2f09
--- /dev/null
+++ b/examples/camcontrol_Bridge/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f4670ba935c0d5dd9e5f1b3632c6b1c8ca45a2e2af80db623e953171c037c77
+size 453479
diff --git a/examples/camcontrol_Bridge/motion_signal.mp4 b/examples/camcontrol_Bridge/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8291ca4d2ab18576577b24eb9d34b4cf7e0ed336
--- /dev/null
+++ b/examples/camcontrol_Bridge/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c92b189953a948a73e0803a53a6e433e752de3bf0b57afd894a9329144310ce
+size 1725802
diff --git a/examples/camcontrol_Bridge/prompt.txt b/examples/camcontrol_Bridge/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d2b14390ffeae337c7b29cfb94b3a06e9cfb5cd0
--- /dev/null
+++ b/examples/camcontrol_Bridge/prompt.txt
@@ -0,0 +1 @@
+A stone bridge arches over a narrow river winding through a steep canyon. The camera flies low along the river, gliding beneath the bridge, where water sparkles and echoes against the rock walls. Slowly, the view rises upward, revealing vast green forests blanketing the surrounding hills and valleys. Warm sunlight filters through the trees, highlighting the lush textures of the landscape and adding depth to the sweeping forest view.
\ No newline at end of file
diff --git a/examples/camcontrol_ConcertCrowd/first_frame.png b/examples/camcontrol_ConcertCrowd/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7b8ddab242da2286467c0f2ff06ee9edfe0eeb8
--- /dev/null
+++ b/examples/camcontrol_ConcertCrowd/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf02fc1b2f175c09d0e2f2af93089dc55dfbb20ea7c7c48971821912b431e181
+size 696431
diff --git a/examples/camcontrol_ConcertCrowd/mask.mp4 b/examples/camcontrol_ConcertCrowd/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..94f7fbd40d43daa39f2463951309a7e2f3c63293
--- /dev/null
+++ b/examples/camcontrol_ConcertCrowd/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29652ba7927e0dea5744dc107b334d63cab72e0ebeb502ecb4bdcbf07d7cfcc4
+size 989549
diff --git a/examples/camcontrol_ConcertCrowd/motion_signal.mp4 b/examples/camcontrol_ConcertCrowd/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0385ea14d7ab05327ff6647f3c08d47205f1da4d
--- /dev/null
+++ b/examples/camcontrol_ConcertCrowd/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4dafbcc4a0cdbceef618d8d179004f2093a68db370cc9a28fdbcc970def9b26a
+size 4568859
diff --git a/examples/camcontrol_ConcertCrowd/prompt.txt b/examples/camcontrol_ConcertCrowd/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9a4709aee5045dae36bbbffc0f029d61673b6d1a
--- /dev/null
+++ b/examples/camcontrol_ConcertCrowd/prompt.txt
@@ -0,0 +1 @@
+A massive crowd fills the arena, their energy palpable as they sway and cheer under the dazzling lights. Confetti flutters down from above, adding to the electric atmosphere. The audience, a sea of raised hands and excited faces, pulses with anticipation as the music builds. Security personnel stand at the forefront, ensuring safety while the crowd's enthusiasm grows. The stage is set for an unforgettable performance, with the audience fully immersed in the moment, ready to sing along and dance to the rhythm of the night.
\ No newline at end of file
diff --git a/examples/camcontrol_ConcertStage/first_frame.png b/examples/camcontrol_ConcertStage/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..004d42e114987d38db1a8701c0690493af5f100a
--- /dev/null
+++ b/examples/camcontrol_ConcertStage/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f370134f1d4fec17ebd90ce1ba970c373127ac38a47ede8741e9c6edd00c4540
+size 559447
diff --git a/examples/camcontrol_ConcertStage/mask.mp4 b/examples/camcontrol_ConcertStage/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f094b7b9d49326b86171ca39e27378af60430998
--- /dev/null
+++ b/examples/camcontrol_ConcertStage/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31c2037732bacd73d4d8d723270e042c1898eb4a5021bb9e16f8b08ff30a5686
+size 845639
diff --git a/examples/camcontrol_ConcertStage/motion_signal.mp4 b/examples/camcontrol_ConcertStage/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..555ae6900dec963181ac0691b5763c06d7a5a402
--- /dev/null
+++ b/examples/camcontrol_ConcertStage/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b6b5b4e3aa1e4b3f719a800a1bdc7ab2f4aab18943011d06b0264ed44bf4504
+size 2643188
diff --git a/examples/camcontrol_ConcertStage/prompt.txt b/examples/camcontrol_ConcertStage/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5dada3b12acebae0bb14dfee5a2ce3eacdbcc141
--- /dev/null
+++ b/examples/camcontrol_ConcertStage/prompt.txt
@@ -0,0 +1 @@
+The concert arena is electrified with energy as a band performs on stage, bathed in vibrant blue and purple lights. Flames shoot up dramatically on either side, adding to the intensity of the performance. The crowd is a sea of raised hands, swaying and cheering in unison, completely immersed in the music. In the foreground, a fan with an ecstatic expression captures the moment on their phone, while others around them shout and sing along. The atmosphere is charged with excitement and the pulsating rhythm of the music reverberates through the venue.
\ No newline at end of file
diff --git a/examples/camcontrol_RiverOcean/first_frame.png b/examples/camcontrol_RiverOcean/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..23feb9112b8b11e6c94377db4641f5fad011446f
--- /dev/null
+++ b/examples/camcontrol_RiverOcean/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8a931afdfebcf65dcf04a098146fb8922544eaea0fbea725bf8ee324efa92b5
+size 668616
diff --git a/examples/camcontrol_RiverOcean/mask.mp4 b/examples/camcontrol_RiverOcean/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..44f2d44b20b9cee456162ae6286a0627fec29ac9
--- /dev/null
+++ b/examples/camcontrol_RiverOcean/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:993427e110aaf4ef4b319f7d35c27ccbab2ec42a90916ecaab508179cac3d426
+size 528205
diff --git a/examples/camcontrol_RiverOcean/motion_signal.mp4 b/examples/camcontrol_RiverOcean/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5128f4f724644ceb4600bcf16f4802af407a05d8
--- /dev/null
+++ b/examples/camcontrol_RiverOcean/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:189edb6a3e5d5491c4849e12b1acaf994a9978b8004dc6cc5f9449b34e28752b
+size 2424199
diff --git a/examples/camcontrol_RiverOcean/prompt.txt b/examples/camcontrol_RiverOcean/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cb271b3871a44adb952fce9eec5b7ac58af40dcd
--- /dev/null
+++ b/examples/camcontrol_RiverOcean/prompt.txt
@@ -0,0 +1 @@
+A serene river winds its way through lush greenery, bordered by dense forests and sandy shores, before gracefully merging with the vast ocean. The gentle waves of the ocean lap against the rocky cliffs, creating a harmonious blend of land and sea. The sun casts a warm glow over the landscape, highlighting the vibrant colors of the foliage and the shimmering water. As the scene unfolds, birds soar above, and the gentle breeze rustles the leaves, adding a sense of tranquility to this picturesque meeting of river and ocean.
\ No newline at end of file
diff --git a/examples/camcontrol_SpiderMan/first_frame.png b/examples/camcontrol_SpiderMan/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..91879246d8721633516580eb75f56db5f66d025d
--- /dev/null
+++ b/examples/camcontrol_SpiderMan/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5d3491afe98d04b32fba4b9b95301ebb3f787a96c0125f09c7e4ef14084d9e0
+size 550720
diff --git a/examples/camcontrol_SpiderMan/mask.mp4 b/examples/camcontrol_SpiderMan/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e4115805ded182720cfeb3222a7ab9631170a021
--- /dev/null
+++ b/examples/camcontrol_SpiderMan/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c65b3cd42cbb622b702ab740b35546f342b7a9e4bf611b76ad9399f21b45f1a0
+size 873522
diff --git a/examples/camcontrol_SpiderMan/motion_signal.mp4 b/examples/camcontrol_SpiderMan/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9012a8777e05f8e7f8731cad4421765ca237a798
--- /dev/null
+++ b/examples/camcontrol_SpiderMan/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2c98573f7a9c8b5716dcfc19b51a0193dae0d8aa28dfed23f2b62440321f90f
+size 1821884
diff --git a/examples/camcontrol_SpiderMan/prompt.txt b/examples/camcontrol_SpiderMan/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..91b1711f2c504022a15dc65b76a451bbb4d04804
--- /dev/null
+++ b/examples/camcontrol_SpiderMan/prompt.txt
@@ -0,0 +1 @@
+A superhero in a red and blue suit swings gracefully through the towering skyscrapers of a bustling city. The sun sets in the distance, casting a warm golden glow over the urban landscape. As he releases his web, he flips through the air with agility, preparing to latch onto another building. The streets below are filled with the hustle and bustle of city life, while the hero moves effortlessly above, embodying a sense of freedom and adventure.
\ No newline at end of file
diff --git a/examples/camcontrol_Volcano/first_frame.png b/examples/camcontrol_Volcano/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f22c86b4ab42088c2856cadfe3bf6f672bc8af7
--- /dev/null
+++ b/examples/camcontrol_Volcano/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d668c88c9c081324cdcb5aeb74f9633759ef44ced0d80e587c033420ef61c5da
+size 550142
diff --git a/examples/camcontrol_Volcano/mask.mp4 b/examples/camcontrol_Volcano/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..813aab3d8c7bdb0ccf4c0f531e392702e6e0eb28
--- /dev/null
+++ b/examples/camcontrol_Volcano/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cb0974c3b9c055dfdded4b1081c65c068af133755d3dc0ce763e088fcbfebd8
+size 612213
diff --git a/examples/camcontrol_Volcano/motion_signal.mp4 b/examples/camcontrol_Volcano/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a6a6c7119d7c6fb8567bbdda7552650a75eebf67
--- /dev/null
+++ b/examples/camcontrol_Volcano/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:418fe3ec7394684193be49b6ecff9f5e4228eb667eb5ffa5a0050d8e260fd95a
+size 1510493
diff --git a/examples/camcontrol_Volcano/prompt.txt b/examples/camcontrol_Volcano/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..81eea843b24a05e87828c9d4cedafc4c0daf98a6
--- /dev/null
+++ b/examples/camcontrol_Volcano/prompt.txt
@@ -0,0 +1 @@
+Under a starry night sky illuminated by the ethereal glow of the aurora borealis, a volcano erupts with fierce intensity. Molten lava spews high into the air, casting a fiery orange glow against the darkened landscape. The lava cascades down the sides of the volcano, forming glowing rivers that snake across the rugged terrain. Ash and smoke billow upwards, mingling with the vibrant colors of the northern lights, creating a dramatic and mesmerizing spectacle. The scene is both awe-inspiring and formidable, capturing the raw power and beauty of nature in its most elemental form.
\ No newline at end of file
diff --git a/examples/camcontrol_VolcanoTitan/first_frame.png b/examples/camcontrol_VolcanoTitan/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..c1832748fcf4bae11bfabea0bdaf47595f9490f0
--- /dev/null
+++ b/examples/camcontrol_VolcanoTitan/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19f3f77f46324d180fa7c75466b45ff19a30261d01bf8bcb1ba4eb60c888a0f1
+size 558790
diff --git a/examples/camcontrol_VolcanoTitan/mask.mp4 b/examples/camcontrol_VolcanoTitan/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..90fb071642e6b65a91b157de351d97693f41be38
--- /dev/null
+++ b/examples/camcontrol_VolcanoTitan/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f1b619fff98a78b5e39dd590727540514d93fcafbe042933086c969b0791139
+size 393530
diff --git a/examples/camcontrol_VolcanoTitan/motion_signal.mp4 b/examples/camcontrol_VolcanoTitan/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f4117a9688e07bf43431b77a74efe3d7148fd8a0
--- /dev/null
+++ b/examples/camcontrol_VolcanoTitan/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a2b4efbf710f71124bf2033638097150b573e10511f1b714428b597aa74d9ed
+size 2147868
diff --git a/examples/camcontrol_VolcanoTitan/prompt.txt b/examples/camcontrol_VolcanoTitan/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e20d756f1d399519ff36d456094df7b0316a2436
--- /dev/null
+++ b/examples/camcontrol_VolcanoTitan/prompt.txt
@@ -0,0 +1 @@
+A colossal titan emerges from the molten heart of an erupting volcano, its body composed of dark, jagged rock with glowing veins of fiery lava. Lightning crackles across the stormy sky, illuminating the titan's menacing form as it rises with a slow, deliberate motion. Rivers of lava cascade down the volcano's slopes, carving fiery paths through the landscape. The titan's eyes burn with an intense, molten glow, and its massive hands grip the edges of the lava pool, sending tremors through the ground. As the eruption intensifies, the titan lets out a thunderous roar, echoing across the volcanic terrain.
\ No newline at end of file
diff --git a/examples/cutdrag_cog_Monkey/first_frame.png b/examples/cutdrag_cog_Monkey/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..b08f485b2e0f7bb644f7f5693be441f6cc314300
--- /dev/null
+++ b/examples/cutdrag_cog_Monkey/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f445e634e3a9fd5145b82535c97b06886e327f2e46bbfbc87459829f40839aa
+size 493607
diff --git a/examples/cutdrag_cog_Monkey/mask.mp4 b/examples/cutdrag_cog_Monkey/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2c46422eb00696c65c4d000d14d395166b6af601
Binary files /dev/null and b/examples/cutdrag_cog_Monkey/mask.mp4 differ
diff --git a/examples/cutdrag_cog_Monkey/motion_signal.mp4 b/examples/cutdrag_cog_Monkey/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..707c62518300087b8401fa19f4e5daa8582eea5a
--- /dev/null
+++ b/examples/cutdrag_cog_Monkey/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79eaff99c0e0322e7623029bf2211bff3ad4bbb53fdea511fc6dfa62c116514b
+size 299834
diff --git a/examples/cutdrag_cog_Monkey/prompt.txt b/examples/cutdrag_cog_Monkey/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..01344230a4c33a1192cc71016876592db73ee846
--- /dev/null
+++ b/examples/cutdrag_cog_Monkey/prompt.txt
@@ -0,0 +1 @@
+A lively monkey energetically bounces on a neatly made bed, its limbs splayed in mid-air. As the monkey lands, the bed creases slightly under its weight, and it quickly prepares for another joyful leap, its eyes wide with excitement and mischief.
\ No newline at end of file
diff --git a/examples/cutdrag_svd_Fish/first_frame.png b/examples/cutdrag_svd_Fish/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d4fd5cbb643ba9aa9715d83b8bae5bb72651850
--- /dev/null
+++ b/examples/cutdrag_svd_Fish/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67d7d2e837685604be229a94f65d087f714436fcef2f74360ffe3e07940d69a5
+size 519954
diff --git a/examples/cutdrag_svd_Fish/mask.mp4 b/examples/cutdrag_svd_Fish/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..dfe7de9c0fb10e4e2dc0088db4f1438bcd4ec5ec
Binary files /dev/null and b/examples/cutdrag_svd_Fish/mask.mp4 differ
diff --git a/examples/cutdrag_svd_Fish/motion_signal.mp4 b/examples/cutdrag_svd_Fish/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5f30bf1ad2a3179bdbf08aa06a8aa3d34505ce57
--- /dev/null
+++ b/examples/cutdrag_svd_Fish/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92c6343da8c5028390c522aeac9a5f3a258eaf09660f74f7c977c15e40aa3485
+size 174696
diff --git a/examples/cutdrag_wan_Birds/first_frame.png b/examples/cutdrag_wan_Birds/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..924e8810c08fcfeafc82361ac20578b49219a378
--- /dev/null
+++ b/examples/cutdrag_wan_Birds/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:505a7edb31d6df4a3f4fcaeeb519dd97f11fec6c6e1d74a4da6fc43e2d7f5837
+size 271357
diff --git a/examples/cutdrag_wan_Birds/mask.mp4 b/examples/cutdrag_wan_Birds/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..eee80aadc33c5d1d26bb21b1c57cab057e80e6cd
Binary files /dev/null and b/examples/cutdrag_wan_Birds/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Birds/motion_signal.mp4 b/examples/cutdrag_wan_Birds/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..da75bc85f1b922ae416999c516970b3abd10a746
--- /dev/null
+++ b/examples/cutdrag_wan_Birds/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fae72bb1f6c2a284727ef267ce910b7c6b36492b69f1fe5c5483ae2a6a60d10a
+size 482308
diff --git a/examples/cutdrag_wan_Birds/prompt.txt b/examples/cutdrag_wan_Birds/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e1fcd168afafd079bd1b379f84ad5230ff5e4b8a
--- /dev/null
+++ b/examples/cutdrag_wan_Birds/prompt.txt
@@ -0,0 +1 @@
+As the sun sets, casting a warm glow across the sky, three birds glide gracefully through the air. A majestic eagle leads the way, its powerful wings outstretched, catching the last rays of sunlight. Beside it, a swift falcon darts with precision, its sleek form cutting through the gentle breeze. Below them, a swallow flits playfully, its agile movements creating a dance against the backdrop of rolling hills and silhouetted trees. The scene is serene, with the birds moving in harmony, painting a picture of freedom and grace against the vibrant hues of the evening sky.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Cocktail/first_frame.png b/examples/cutdrag_wan_Cocktail/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..8086deffb50c5773502937fcc6da5efddb7c3093
--- /dev/null
+++ b/examples/cutdrag_wan_Cocktail/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce69e1b95630a4d7acbe34ac79a04ab38aadcd34e599a9ee971d7e66e93bdb2a
+size 468626
diff --git a/examples/cutdrag_wan_Cocktail/mask.mp4 b/examples/cutdrag_wan_Cocktail/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..858aa889caf1eb2118d5f105445a4021695b7852
Binary files /dev/null and b/examples/cutdrag_wan_Cocktail/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Cocktail/motion_signal.mp4 b/examples/cutdrag_wan_Cocktail/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3a327a42aa51e6c432647630d7581321ae5df1b1
--- /dev/null
+++ b/examples/cutdrag_wan_Cocktail/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daf025b76a4c86b5928a7bbda85ed252f03c307a4d2e823d7a52057bd73f2215
+size 138233
diff --git a/examples/cutdrag_wan_Cocktail/prompt.txt b/examples/cutdrag_wan_Cocktail/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a33cffdb13c535f0db2d6e98472ef3969e68bad2
--- /dev/null
+++ b/examples/cutdrag_wan_Cocktail/prompt.txt
@@ -0,0 +1 @@
+From the vibrant chaos of the neon-lit bar, a single spotlight finds its focus, illuminating a cocktail glass holding glistening ice and clear liquor. A thick, deep crimson cherry syrup begins to cascade from above the camera view in a stream, its rich color cutting through the light. As it descends, the syrup weaves a liquid trail through the ice, bleeding into the clear liquid to create a mesmerizing swirl of ruby red. The once-separate elements of ice, alcohol, and syrup dance and mingle, each drop transforming the drink until it becomes a singular, brilliant red jewel. The process is a silent spectacle, a concentrated moment of creation that stands out against the backdrop of the bustling, vibrant night, culminating in a beautiful and vivid cocktail ready to be enjoyed.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Gardening/first_frame.png b/examples/cutdrag_wan_Gardening/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..248b3635ade204a2f6b8623e781df784e32f7dd0
--- /dev/null
+++ b/examples/cutdrag_wan_Gardening/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f45df81e0389c0facf5c8c2308d57e69eb46d53c36727081b54646f3d19f760f
+size 513295
diff --git a/examples/cutdrag_wan_Gardening/mask.mp4 b/examples/cutdrag_wan_Gardening/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..34fa3a9c3d380399e1ee31d8200fb5484a7b861d
Binary files /dev/null and b/examples/cutdrag_wan_Gardening/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Gardening/motion_signal.mp4 b/examples/cutdrag_wan_Gardening/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..dca4b9f59994f0007eb8e88c86ff5afe21ac7d55
--- /dev/null
+++ b/examples/cutdrag_wan_Gardening/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07b6987a80aa3304cbb6892b253a2dc098b7c96e904e549ffbc19c16cc5b75d6
+size 342477
diff --git a/examples/cutdrag_wan_Gardening/prompt.txt b/examples/cutdrag_wan_Gardening/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ce9ae59f6481e137f481358bf644fab0eba849ca
--- /dev/null
+++ b/examples/cutdrag_wan_Gardening/prompt.txt
@@ -0,0 +1 @@
+A woman stands in a lush garden, gently watering the vibrant plants with a hose. The sun is setting behind her, casting a warm, golden glow over the scene. She wears a wide-brimmed hat and a light dress, embodying a serene and focused presence. The garden is filled with a variety of flowers and greenery, including sunflowers and hydrangeas, all thriving under her care. In the background, a charming farmhouse sits amidst rolling fields, completing the idyllic countryside setting. As she moves, the water sparkles in the sunlight, creating a peaceful and nurturing atmosphere.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Hamburger/first_frame.png b/examples/cutdrag_wan_Hamburger/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..50ff4ca0828c2e00ada2d3be521f6ecfa31fa8cd
--- /dev/null
+++ b/examples/cutdrag_wan_Hamburger/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:123cc4eb56a15341c1fd400366ccfb9a1b9438b3cfb1a016fe072a0b80a45a8e
+size 413339
diff --git a/examples/cutdrag_wan_Hamburger/mask.mp4 b/examples/cutdrag_wan_Hamburger/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..253be902ff3b3bdee3c657c271f55c7d90348df0
Binary files /dev/null and b/examples/cutdrag_wan_Hamburger/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Hamburger/motion_signal.mp4 b/examples/cutdrag_wan_Hamburger/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ea5c41d2c14d042426ff0691846f506559586494
--- /dev/null
+++ b/examples/cutdrag_wan_Hamburger/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d94ff5c3b4fe26c297800a28ca1e69bccd388e5ff8b0562e7dcaaf664f99898
+size 264818
diff --git a/examples/cutdrag_wan_Hamburger/prompt.txt b/examples/cutdrag_wan_Hamburger/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c1daa670188716fd5ee48307e8b667047029f5df
--- /dev/null
+++ b/examples/cutdrag_wan_Hamburger/prompt.txt
@@ -0,0 +1 @@
+A woman eagerly takes a big bite of a juicy hamburger, layered with fresh lettuce, ripe tomato slices, crispy bacon, melted cheese, and red onion, all nestled between a sesame seed bun. The burger's sauce drips slightly, adding to the mouthwatering appeal. The background is softly blurred, with warm lights creating a cozy ambiance, suggesting a relaxed dining setting.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Jumping/first_frame.png b/examples/cutdrag_wan_Jumping/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..417bf26a2bfdd21cc563b9ad62044909c5501e97
--- /dev/null
+++ b/examples/cutdrag_wan_Jumping/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0642100ff03b6e11bc4dbd6c6634570f7a5b7a07434547a9004e3a0e769efc40
+size 109884
diff --git a/examples/cutdrag_wan_Jumping/mask.mp4 b/examples/cutdrag_wan_Jumping/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2440021d58fe4f9d4042c00308536e38787550bd
Binary files /dev/null and b/examples/cutdrag_wan_Jumping/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Jumping/motion_signal.mp4 b/examples/cutdrag_wan_Jumping/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..49cca954d7054c3f2b73bfba1c966c75844acd2c
Binary files /dev/null and b/examples/cutdrag_wan_Jumping/motion_signal.mp4 differ
diff --git a/examples/cutdrag_wan_Jumping/prompt.txt b/examples/cutdrag_wan_Jumping/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..70c9d3defdf96664ef427e35fde878cd3b41c135
--- /dev/null
+++ b/examples/cutdrag_wan_Jumping/prompt.txt
@@ -0,0 +1 @@
+A young child stands beside a concrete block on a vast, empty runway. The child, wearing a blue shirt and shorts, looks around curiously. Suddenly, with a burst of energy, he begins to jump up and down, his laughter echoing across the open space. The sky is overcast, and the distant trees frame the scene, adding a sense of freedom and adventure to his playful movements.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Monkey/first_frame.png b/examples/cutdrag_wan_Monkey/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..b08f485b2e0f7bb644f7f5693be441f6cc314300
--- /dev/null
+++ b/examples/cutdrag_wan_Monkey/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f445e634e3a9fd5145b82535c97b06886e327f2e46bbfbc87459829f40839aa
+size 493607
diff --git a/examples/cutdrag_wan_Monkey/mask.mp4 b/examples/cutdrag_wan_Monkey/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..304eed69c558807fc2e11b6c72ad57f8c894e2ef
Binary files /dev/null and b/examples/cutdrag_wan_Monkey/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Monkey/motion_signal.mp4 b/examples/cutdrag_wan_Monkey/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ed42945993334224e72b7656c650feec640f5447
--- /dev/null
+++ b/examples/cutdrag_wan_Monkey/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cabb839e51bc6f348f96a15ba683b41b63e29b0ed46559791651b60a5802011
+size 337998
diff --git a/examples/cutdrag_wan_Monkey/prompt.txt b/examples/cutdrag_wan_Monkey/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..01344230a4c33a1192cc71016876592db73ee846
--- /dev/null
+++ b/examples/cutdrag_wan_Monkey/prompt.txt
@@ -0,0 +1 @@
+A lively monkey energetically bounces on a neatly made bed, its limbs splayed in mid-air. As the monkey lands, the bed creases slightly under its weight, and it quickly prepares for another joyful leap, its eyes wide with excitement and mischief.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Owl/first_frame.png b/examples/cutdrag_wan_Owl/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ad4933987842c35b63a39c0da496925327840ce
--- /dev/null
+++ b/examples/cutdrag_wan_Owl/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2cee4c7b54069b690e46d1074f6177de6164f034c28770ba7c5ea5674ea7db5
+size 239576
diff --git a/examples/cutdrag_wan_Owl/mask.mp4 b/examples/cutdrag_wan_Owl/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a5458e6ff1e3657195d2fe9740cff2ccb084f0ad
Binary files /dev/null and b/examples/cutdrag_wan_Owl/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Owl/motion_signal.mp4 b/examples/cutdrag_wan_Owl/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0d96c94dfa4e1cf0086a82f900452aa756f3d073
--- /dev/null
+++ b/examples/cutdrag_wan_Owl/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b3413a0c0228bfeaebe4c1a818f10aeec7a7f3b667c6103b75ea72478724c1e
+size 134971
diff --git a/examples/cutdrag_wan_Owl/prompt.txt b/examples/cutdrag_wan_Owl/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..17a9976c1dcee69b99c0bc714bcba7fa48db47a7
--- /dev/null
+++ b/examples/cutdrag_wan_Owl/prompt.txt
@@ -0,0 +1 @@
+A majestic snowy owl perches gracefully on a gnarled branch, its pristine white feathers adorned with delicate black speckles. The owl's piercing yellow eyes are wide and alert, scanning the surroundings with a sense of calm authority. As a gentle breeze rustles through the leaves, the owl remains poised, its sharp talons gripping the branch securely. The dark, blurred background accentuates the owl's striking presence, creating a serene yet powerful scene in the quiet of the night.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Rhino/first_frame.png b/examples/cutdrag_wan_Rhino/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..2e9427031ce500cd45878a369a2527e99ec52e2d
--- /dev/null
+++ b/examples/cutdrag_wan_Rhino/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98ade3a4f8c6971434d1cef44592bf873683188c21e96040da376cd4ff7fb292
+size 349188
diff --git a/examples/cutdrag_wan_Rhino/mask.mp4 b/examples/cutdrag_wan_Rhino/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..99178074159f15bed96f0e0830a8f7ed90b4d2a3
Binary files /dev/null and b/examples/cutdrag_wan_Rhino/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Rhino/motion_signal.mp4 b/examples/cutdrag_wan_Rhino/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6036d331223ae2c1c25354d9bf5e508621d17909
--- /dev/null
+++ b/examples/cutdrag_wan_Rhino/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:084185b67d107ab7974e1f88364fa0979755099c9d2c722981256caca2ce0d5a
+size 206075
diff --git a/examples/cutdrag_wan_Rhino/prompt.txt b/examples/cutdrag_wan_Rhino/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4b1940707024239d6561b1dfe9889c736f7aa303
--- /dev/null
+++ b/examples/cutdrag_wan_Rhino/prompt.txt
@@ -0,0 +1 @@
+The scene opens on a breathtaking savanna, bathed in the fiery hues of a twilight sky, where a small boy stands in silent awe next to an enormous rhino. With an electric burst of energy, the boy's face breaks into a wide grin, and he launches himself into a single, impossible leap, a blur of motion against the dramatic landscape. For a fleeting moment, he hangs in the air, a tiny, determined silhouette before landing with a soft thud on the rhino's massive back. There is no fear, only pure, unadulterated joy as a triumphant whoop escapes his lips, echoing across the plains. He settles comfortably on his new perch, his small hands holding on to the bristly hide, a perfect picture of thrilled connection. The rhino continues to stand in its quiet dignity, seemingly unfazed, as the boy's infectious laughter rings out, a perfect harmony with the fading light of the African twilight.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_Surfing/first_frame.png b/examples/cutdrag_wan_Surfing/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc5337b6d395eb7d9982d2346d99b7a2f8cb4045
--- /dev/null
+++ b/examples/cutdrag_wan_Surfing/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b7502c75e09eb7ef787cc776def7446c7fb2fb04294c9d3da9461b5bf9d5d56
+size 416054
diff --git a/examples/cutdrag_wan_Surfing/mask.mp4 b/examples/cutdrag_wan_Surfing/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..51ef74371e7520927e81093dc5da2de615631342
Binary files /dev/null and b/examples/cutdrag_wan_Surfing/mask.mp4 differ
diff --git a/examples/cutdrag_wan_Surfing/motion_signal.mp4 b/examples/cutdrag_wan_Surfing/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..fec9eb1e6fb41f231d36d06a1dc1a7caf3294e6e
--- /dev/null
+++ b/examples/cutdrag_wan_Surfing/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:baa3a2761e899a545751456ed2151b0d6b7be838efa4f22c7f0c579baaf6fc89
+size 464455
diff --git a/examples/cutdrag_wan_Surfing/prompt.txt b/examples/cutdrag_wan_Surfing/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d3049834c82ac503f8a3584dac0728db53fa3461
--- /dev/null
+++ b/examples/cutdrag_wan_Surfing/prompt.txt
@@ -0,0 +1 @@
+A surfer expertly rides inside the barrel of a powerful, curling wave. The sunlight glistens off the ocean surface, casting a shimmering glow across the water. As the wave crashes around him, he maintains perfect balance, skillfully maneuvering his board to stay within the wave's hollow. The sky is a soft blend of blues and whites, adding to the serene yet exhilarating atmosphere of the scene.
\ No newline at end of file
diff --git a/examples/cutdrag_wan_TimeSquares/first_frame.png b/examples/cutdrag_wan_TimeSquares/first_frame.png
new file mode 100644
index 0000000000000000000000000000000000000000..1045befe84e1f58467aabfad3436fb963ee0cefe
--- /dev/null
+++ b/examples/cutdrag_wan_TimeSquares/first_frame.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd00d595c7bbcde120b349d866f1923ccf630c40ec3414fe5c912d59b0e0c0e5
+size 610625
diff --git a/examples/cutdrag_wan_TimeSquares/mask.mp4 b/examples/cutdrag_wan_TimeSquares/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2f091ff22540fb4d13a92c550c7dc5de761d30ea
Binary files /dev/null and b/examples/cutdrag_wan_TimeSquares/mask.mp4 differ
diff --git a/examples/cutdrag_wan_TimeSquares/motion_signal.mp4 b/examples/cutdrag_wan_TimeSquares/motion_signal.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7303f8aec809640d18b60d6ed0a40516e7d3fbbb
--- /dev/null
+++ b/examples/cutdrag_wan_TimeSquares/motion_signal.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8cbb0bcdbe3b8fe8de07f5ee7777423e0e1e1b73e9cad7458e54a5ce1a573e33
+size 485465
diff --git a/examples/cutdrag_wan_TimeSquares/prompt.txt b/examples/cutdrag_wan_TimeSquares/prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..22a526baa9168abc1da72cd335f2d496c4d54a9d
--- /dev/null
+++ b/examples/cutdrag_wan_TimeSquares/prompt.txt
@@ -0,0 +1 @@
+A woman in a red shirt stands smiling in the bustling heart of Times Square, surrounded by towering skyscrapers and vibrant digital billboards. The scene is alive with the movement of people walking along the sidewalks, while iconic yellow taxis navigate the busy streets. The bright blue sky above adds to the lively atmosphere, as the woman enjoys the energetic vibe of this iconic New York City location.
\ No newline at end of file
diff --git a/huggingface_space/README.md b/huggingface_space/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..646d28b4c4e4e116f520ecf8b1001384a535bd31
--- /dev/null
+++ b/huggingface_space/README.md
@@ -0,0 +1,30 @@
+# Hugging Face Space Template
+
+Files in this folder provide a turnkey Gradio interface for Time-to-Move using the Wan 2.2 backbone. Copy the folder into a new 🤗 Space repository (or set the Space’s working directory to this folder) to expose an interactive demo.
+
+## Files
+
+- `app.py` – Gradio app that wraps `WanImageToVideoTTMPipeline`. Users can run any folder under `examples/` or upload their own `first_frame.png`, `mask.mp4`, and `motion_signal.mp4`.
+- `requirements.txt` – Runtime dependencies; the Space installs these automatically.
+
+## Deployment steps
+
+1. Create a Space (`huggingface-cli repo create my-org/time-to-move --type=space --space_sdk=gradio`).
+2. Copy the entire `huggingface_space/` directory into that repository and push.
+3. In the Space settings:
+ - Select an **A100** or **A10** runtime (Wan 2.2 needs >30 GB VRAM).
+ - Add `WAN_MODEL_ID` if you mirror Wan internally (default pulls from `Wan-AI/Wan2.2-I2V-A14B-Diffusers`).
+ - Add an access token secret (`HF_TOKEN`) if the Wan checkpoint is private.
+ - Optional: set `TTM_EXAMPLES_DIR` if the examples folder lives somewhere else.
+
+The app automatically discovers any folders inside `examples/` that contain `first_frame.png`, `mask.mp4`, `motion_signal.mp4`, and `prompt.txt`, so include whichever cut-and-drag sets you want to showcase.
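+
+For reference, an example folder looks like this (the folder name is arbitrary; `prompt.txt` is optional and is used when the prompt box is left empty):
+
+```text
+examples/
+└── cutdrag_wan_Owl/
+    ├── first_frame.png       # conditioning image
+    ├── mask.mp4              # binary motion mask
+    ├── motion_signal.mp4     # cut-and-drag motion signal
+    └── prompt.txt            # text prompt
+```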
+
+## Local smoke test
+
+```bash
+pip install -r requirements.txt
+WAN_MODEL_ID=Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+python app.py
+```
+
+If you can generate one of the bundled examples locally, the same setup will work once pushed to Hugging Face Spaces.
diff --git a/huggingface_space/app.py b/huggingface_space/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..227f59c6d967655a58a184d97d66bfaff8ea84f9
--- /dev/null
+++ b/huggingface_space/app.py
@@ -0,0 +1,314 @@
+"""
+Gradio Space entrypoint for Time-to-Move (Wan 2.2 backbone).
+
+The UI allows users to run the provided cut-and-drag examples or upload their own
+`first_frame.png`, `mask.mp4`, and `motion_signal.mp4` triplet. All preprocessing
+matches the `run_wan.py` script in the main repository.
+"""
+
+import base64
+import os
+import tempfile
+from pathlib import Path
+from typing import Dict, Optional, Tuple
+
+import gradio as gr
+import torch
+from diffusers.utils import export_to_video, load_image
+
+from pipelines.utils import compute_hw_from_area
+from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
+
+
+WAN_MODEL_ID = os.getenv("WAN_MODEL_ID", "Wan-AI/Wan2.2-I2V-A14B-Diffusers")
+EXAMPLES_DIR = Path(os.getenv("TTM_EXAMPLES_DIR", "examples"))
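+# Default negative prompt used by the upstream Wan 2.2 examples (intentionally kept in Chinese).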
+DEFAULT_NEGATIVE_PROMPT = (
+ "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
+ "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
+ "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+)
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
+
+_PIPELINE: Optional[WanImageToVideoTTMPipeline] = None
+_MOD_VALUE: Optional[int] = None
+
+
+def _build_example_index() -> Dict[str, Dict[str, str]]:
+ """Scan `examples/` for Wan-compatible folders that contain the required files."""
+ index: Dict[str, Dict[str, str]] = {}
+ if not EXAMPLES_DIR.exists():
+ return index
+
+ for folder in sorted(EXAMPLES_DIR.iterdir()):
+ if not folder.is_dir():
+ continue
+ image = folder / "first_frame.png"
+ mask = folder / "mask.mp4"
+ motion = folder / "motion_signal.mp4"
+ prompt_file = folder / "prompt.txt"
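+        # prompt.txt is optional; when it is missing, the UI falls back to the user-typed prompt.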
+ if not (image.exists() and mask.exists() and motion.exists()):
+ continue
+ index[folder.name] = {
+ "folder": str(folder.resolve()),
+ "prompt": prompt_file.read_text(encoding="utf-8").strip() if prompt_file.exists() else "",
+ }
+ return index
+
+
+EXAMPLE_INDEX = _build_example_index()
+
+
+def _ensure_pipeline() -> WanImageToVideoTTMPipeline:
+ """Lazy-load the Wan Time-to-Move pipeline."""
+ global _PIPELINE, _MOD_VALUE
+ if _PIPELINE is None:
+ pipe = WanImageToVideoTTMPipeline.from_pretrained(WAN_MODEL_ID, torch_dtype=DTYPE)
+ pipe.vae.enable_tiling()
+ pipe.vae.enable_slicing()
+ pipe.to(DEVICE)
+ _PIPELINE = pipe
+ # Height/width must be multiples of vae_scale_factor * patch_size.
+ _MOD_VALUE = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ return _PIPELINE
+
+
+def _save_video_payload(payload, tmpdir: Path, filename: str) -> str:
+ """
+ Persist a video uploaded through Gradio to disk and return its path.
+
+ Hugging Face Spaces provide uploads either as temp file paths or as base64 data URIs.
+ """
+ target = tmpdir / filename
+
+ if payload is None:
+ raise gr.Error(f"Missing upload for {filename}.")
+
+ if isinstance(payload, str) and Path(payload).exists():
+ data = Path(payload).read_bytes()
+ target.write_bytes(data)
+ return str(target)
+
+ if isinstance(payload, dict):
+ # payload["name"] may already point to a temp file on disk.
+ potential_path = payload.get("name")
+ if potential_path and Path(potential_path).exists():
+ data = Path(potential_path).read_bytes()
+ target.write_bytes(data)
+ return str(target)
+
+ raw_data = payload.get("data")
+ if raw_data is None:
+ raise gr.Error(f"Could not read data for {filename}.")
+
+ if isinstance(raw_data, str):
+ # Format: data:video/mp4;base64,AAA...
+ if raw_data.startswith("data:"):
+ raw_data = raw_data.split(",", 1)[1]
+ file_bytes = base64.b64decode(raw_data)
+ else:
+ file_bytes = raw_data
+ target.write_bytes(file_bytes)
+ return str(target)
+
+ raise gr.Error(f"Unsupported upload format for {filename}.")
+
+
+def _prepare_inputs(
+ example_name: str,
+ prompt: str,
+ negative_prompt: str,
+ custom_image,
+ custom_mask,
+ custom_motion,
+) -> Tuple:
+ """
+ Determine which inputs to feed into the pipeline.
+
+ Returns (prompt, negative_prompt, image, mask_path, motion_path).
+ """
+ negative_prompt = (negative_prompt or "").strip() or DEFAULT_NEGATIVE_PROMPT
+
+ if example_name != "custom":
+ meta = EXAMPLE_INDEX.get(example_name)
+ if not meta:
+ raise gr.Error(f"Example '{example_name}' not found in {EXAMPLES_DIR}.")
+ folder = Path(meta["folder"])
+ image = load_image(folder / "first_frame.png")
+ mask_path = str((folder / "mask.mp4").resolve())
+ motion_path = str((folder / "motion_signal.mp4").resolve())
+ resolved_prompt = (prompt or "").strip() or meta["prompt"]
+ if not resolved_prompt:
+ raise gr.Error("Prompt cannot be empty for example runs.")
+ return resolved_prompt, negative_prompt, image, mask_path, motion_path
+
+ # Custom upload path
+ if custom_image is None:
+ raise gr.Error("Upload a first frame (PNG/JPG) for custom mode.")
+ resolved_prompt = (prompt or "").strip()
+ if not resolved_prompt:
+ raise gr.Error("Prompt cannot be empty for custom runs.")
+
+ tmpdir = Path(tempfile.mkdtemp(prefix="ttm_space_inputs_"))
+ mask_path = _save_video_payload(custom_mask, tmpdir, "mask.mp4")
+ motion_path = _save_video_payload(custom_motion, tmpdir, "motion_signal.mp4")
+ return resolved_prompt, negative_prompt, custom_image, mask_path, motion_path
+
+
+def generate_video(
+ example_name: str,
+ prompt: str,
+ negative_prompt: str,
+ tweak_index: int,
+ tstrong_index: int,
+ num_inference_steps: int,
+ guidance_scale: float,
+ num_frames: int,
+ max_area: int,
+ seed: int,
+ custom_image,
+ custom_mask,
+ custom_motion,
+):
+ """Main callback used by Gradio."""
+ pipe = _ensure_pipeline()
+ prompt, negative_prompt, image, mask_path, motion_path = _prepare_inputs(
+ example_name, prompt, negative_prompt, custom_image, custom_mask, custom_motion
+ )
+
+ tweak_index = int(tweak_index)
+ tstrong_index = int(tstrong_index)
+ num_inference_steps = int(num_inference_steps)
+ num_frames = int(num_frames)
+ guidance_scale = float(guidance_scale)
+ seed = int(seed)
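+    # Per the TTM pipelines, denoising starts at `tweak_index` everywhere and at `tstrong_index`
+    # inside the motion-conditioned region, so the indices must be ordered within the schedule.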
+ if not (0 <= tweak_index <= tstrong_index <= num_inference_steps):
+ raise gr.Error("Require 0 ≤ tweak-index ≤ tstrong-index ≤ num_inference_steps.")
+
+ max_area = int(max_area)
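+    # compute_hw_from_area derives an output resolution from the image's aspect ratio,
+    # bounded by `max_area` pixels and aligned to multiples of `mod_value`.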
+ mod_value = _MOD_VALUE or (pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1])
+ height, width = compute_hw_from_area(image.height, image.width, max_area, mod_value)
+ if hasattr(image, "mode") and image.mode != "RGB":
+ image = image.convert("RGB")
+ image = image.resize((width, height))
+
+ generator_device = DEVICE if DEVICE.startswith("cuda") else "cpu"
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+
+ with torch.inference_mode():
+ result = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ motion_signal_video_path=motion_path,
+ motion_signal_mask_path=mask_path,
+ tweak_index=tweak_index,
+ tstrong_index=tstrong_index,
+ )
+
+ frames = result.frames[0]
+ output_dir = Path(tempfile.mkdtemp(prefix="ttm_space_output_"))
+ output_path = output_dir / "ttm.mp4"
+ export_to_video(frames, str(output_path), fps=16)
+
+ status = (
+ f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}\n"
+ f"Resolution: {height}x{width}, frames: {num_frames}, guidance: {guidance_scale}"
+ )
+ return str(output_path), status
+
+
+def build_ui() -> gr.Blocks:
+ example_choices = sorted(EXAMPLE_INDEX.keys())
+ default_example = example_choices[0] if example_choices else "custom"
+
+ with gr.Blocks(title="Time-to-Move (Wan 2.2)") as demo:
+ gr.Markdown(
+ "### Time-to-Move (Wan 2.2)\n"
+ "Generate motion-controlled videos by combining a still frame with a cut-and-drag motion signal. "
+ "Select one of the bundled examples or upload your own trio of files."
+ )
+
+ with gr.Row():
+ example_dropdown = gr.Dropdown(
+ choices=example_choices + ["custom"],
+ value=default_example,
+ label="Example preset",
+ info="Choose a prepackaged example or switch to 'custom' to upload your own inputs.",
+ )
+ prompt_box = gr.Textbox(
+ label="Prompt",
+ lines=4,
+ placeholder="Enter the text prompt (auto-filled for examples).",
+ )
+ negative_prompt_box = gr.Textbox(
+ label="Negative prompt",
+ lines=4,
+ value=DEFAULT_NEGATIVE_PROMPT,
+ )
+
+ with gr.Row():
+ image_input = gr.Image(label="first_frame (custom only)", type="pil")
+ mask_input = gr.Video(label="mask.mp4 (custom only)")
+ motion_input = gr.Video(label="motion_signal.mp4 (custom only)")
+
+ with gr.Row():
+ tweak_slider = gr.Slider(0, 20, value=3, step=1, label="tweak-index")
+ tstrong_slider = gr.Slider(0, 50, value=7, step=1, label="tstrong-index")
+ steps_slider = gr.Slider(10, 50, value=50, step=1, label="num_inference_steps")
+ guidance_slider = gr.Slider(1.0, 8.0, value=3.5, step=0.1, label="guidance_scale")
+
+ with gr.Row():
+ frames_slider = gr.Slider(21, 81, value=81, step=1, label="num_frames")
+ area_slider = gr.Slider(
+ 256 * 256,
+ 640 * 1152,
+ value=480 * 832,
+ step=64,
+ label="max pixel area (height*width)",
+ )
+ seed_box = gr.Number(label="Seed", value=0, precision=0)
+
+ generate_button = gr.Button("Generate video", variant="primary")
+ output_video = gr.Video(label="Generated video", autoplay=True, height=512)
+ status_box = gr.Markdown()
+
+ generate_button.click(
+ fn=generate_video,
+ inputs=[
+ example_dropdown,
+ prompt_box,
+ negative_prompt_box,
+ tweak_slider,
+ tstrong_slider,
+ steps_slider,
+ guidance_slider,
+ frames_slider,
+ area_slider,
+ seed_box,
+ image_input,
+ mask_input,
+ motion_input,
+ ],
+ outputs=[output_video, status_box],
+ )
+
+ info = "\n".join(f"- **{name}**: `{meta['folder']}`" for name, meta in EXAMPLE_INDEX.items())
+ with gr.Accordion("Available packaged examples", open=False):
+ gr.Markdown(info or "No example folders detected. Upload custom inputs instead.")
+
+ return demo
+
+
+app = build_ui()
+
+if __name__ == "__main__":
+ # Enable queuing to support concurrent users on Spaces.
+ app.queue(max_size=2).launch(server_name="0.0.0.0", server_port=7860)
diff --git a/huggingface_space/requirements.txt b/huggingface_space/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f6f39457f60f57c992deb6c67b6c2e866d568041
--- /dev/null
+++ b/huggingface_space/requirements.txt
@@ -0,0 +1,14 @@
+accelerate>=0.26.0
+diffusers>=0.29.0
+ftfy>=6.1.3
+gradio>=4.24.0
+huggingface-hub>=0.23.0
+imageio>=2.34.0
+imageio-ffmpeg>=0.5.0
+opencv-python>=4.9.0
+pillow>=10.2.0
+protobuf>=4.25.2
+sentencepiece>=0.1.99
+torch>=2.1.2
+transformers>=4.39.0
+xformers>=0.0.23.post1 ; platform_system == "Linux"
diff --git a/pipelines/cog_pipeline.py b/pipelines/cog_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec03bbe537c5ec945e08683e83b57e444d1f2185
--- /dev/null
+++ b/pipelines/cog_pipeline.py
@@ -0,0 +1,524 @@
+# Copyright 2025 Noam Rotstein
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from Hugging Face Diffusers (Apache-2.0):
+# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+
+try:
+ from dataclasses import dataclass
+ import math
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ import torch
+ from transformers import T5EncoderModel, T5Tokenizer
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+ from diffusers.utils import (
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import retrieve_timesteps
+ from diffusers import CogVideoXImageToVideoPipeline
+
+ import torch.nn.functional as F
+ from pipelines.utils import load_video_to_tensor
+
+except ImportError as e:
+ raise ImportError(f"Required module not found: {e}. Please install it before running this script. "
+ f"For installation instructions, see: https://github.com/zai-org/CogVideo")
+
+
+try:
+ if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ XLA_AVAILABLE = True
+ else:
+ XLA_AVAILABLE = False
+except ImportError:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+"""
+
+
+class CogVideoXImageToVideoTTMPipeline(CogVideoXImageToVideoPipeline):
+ r"""
+ Pipeline for image-to-video generation using CogVideoX combined with Time to Move (TTM).
+ This model inherits from [`CogVideoXImageToVideoPipeline`].
+ """
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+
+ @torch.no_grad()
+ def encode_frames(self, frames: torch.Tensor, vae_scale_factor: float = None) -> torch.Tensor:
+ """Encode video frames into latent space with shape (B, F, C, H, W). Input shape (B, C, F, H, W), expected range [-1, 1]."""
+ latents = self.vae.encode(frames)[0].sample()
+ # latents = self.vae.encode(frames)[0].mode()
+ vae_scale_factor = vae_scale_factor or self.vae_scaling_factor_image
+ latents = latents * vae_scale_factor
+ return latents.permute(0, 2, 1, 3, 4).contiguous() # shape (B, C, F, H, W) -> (B, F, C, H, W)
+
+
+ def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W'].
+ T_latent groups frames by the temporal VAE downsample factor k = vae_scale_factor_temporal:
+ [0], [1..k], [k+1..2k], ...
+ """
+ k = self.vae_scale_factor_temporal
+
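+        # Example: with num_frames=49 and k=4, frames 0, 1, 5, 9, ..., 45 are kept,
+        # giving T_latent = (49 - 1) // 4 + 1 = 13 mask frames, matching the video latent length.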
+ mask0 = mask[0:1] # [1,1,H,W]
+ mask1 = mask[1::k] # [T'-1,1,H,W]
+ sampled = torch.cat([mask0, mask1], dim=0) # [T',1,H,W]
+ pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0)
+
+        # Downsample spatially to the latent spatial resolution
+ s = self.vae_scale_factor_spatial
+ H_latent = pooled.shape[-2] // s
+ W_latent = pooled.shape[-1] // s
+ pooled = F.interpolate(pooled, size=(pooled.shape[2], H_latent, W_latent), mode="nearest")
+
+ # Back to [1, T_latent, 1, H, W]
+ latent_mask = pooled.permute(0, 2, 1, 3, 4)
+
+ return latent_mask
+
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ motion_signal_video_path: Optional[str] = None,
+ motion_signal_mask_path: Optional[str] = None,
+ tweak_index: int = 0,
+ tstrong_index: int = 0
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated video. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+                The width in pixels of the generated video. This is set to 720 by default for the best results.
+            num_frames (`int`, defaults to `49`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 6):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+            motion_signal_video_path (`str`):
+                Path to the video file containing the motion signal that guides the motion of the generated video.
+                It should be a crude version of the reference video in which the moving pixels are dragged to their
+                target locations.
+            motion_signal_mask_path (`str`):
+                Path to the mask video file containing the TTM motion mask.
+                The mask should be binary, with motion-conditioned pixels set to 1 and all other pixels set to 0.
+            tweak_index (`int`):
+                The timestep index from which the denoising process starts; the latents are initialized from the
+                motion-signal latents noised to this timestep.
+            tstrong_index (`int`):
+                The timestep index from which the denoising process starts inside the motion-conditioned region.
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = num_frames or self.transformer.config.sample_frames
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ if motion_signal_mask_path is None:
+ raise ValueError("`motion_signal_mask_path` is required for TTM.")
+ if motion_signal_video_path is None:
+ raise ValueError("`motion_signal_video_path` is required for TTM.")
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+ patch_size_t = self.transformer.config.patch_size_t
+ additional_frames = 0
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
+ additional_frames = patch_size_t - latent_frames % patch_size_t
+ num_frames += additional_frames * self.vae_scale_factor_temporal
+
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
+ device, dtype=prompt_embeds.dtype
+ )
+
+ latent_channels = self.transformer.config.in_channels // 2
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Create ofs embeds if required
+ ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
+
+ # 9. Initialize for TTM
+ ref_vid = load_video_to_tensor(motion_signal_video_path).to(device=device) # shape [1, C, T, H, W]
+ refB, refC, refT, refH, refW = ref_vid.shape
+ ref_vid = F.interpolate(
+ ref_vid.permute(0, 2, 1, 3, 4).reshape(refB*refT, refC, refH, refW),
+ size=(height, width), mode="bicubic", align_corners=True,
+ ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
+
+ ref_vid = self.video_processor.normalize(ref_vid.to(dtype=self.vae.dtype)) # Normalize and convert dtype for VAE encoding
+ ref_latents = self.encode_frames(ref_vid).float().detach() # shape [1, T, C, H, W]
+
+ ref_mask = load_video_to_tensor(motion_signal_mask_path).to(device=device) # shape [1, C, T, H, W]
+ mB, mC, mT, mH, mW = ref_mask.shape
+
+ ref_mask = F.interpolate(
+ ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
+ size=(height, width), mode="nearest",
+ ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
+ ref_mask = ref_mask[0].permute(1, 0, 2, 3).contiguous() # (1, C, T, H, W) -> (T, C, H, W)
+
+ if len(ref_mask.shape) == 4:
+ ref_mask = ref_mask.unsqueeze(0)
+
+ ref_mask = ref_mask[0,:,:1].contiguous() # (1, T, C, H, W) -> (T, 1, H, W)
+ ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0] # [T, 1, H, W]
+ motion_mask = self.convert_rgb_mask_to_latent_mask(ref_mask) # [1, T, 1, H, W]
+ background_mask = 1.0 - motion_mask
+
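+ # TTM warm start: when tweak_index >= 0, skip the earliest denoising steps and start from the
+ # reference latents noised to the scheduler timestep at tweak_index (t_weak); when tweak_index < 0,
+ # keep the freshly sampled latents and run the full schedule from pure noise.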
+ if tweak_index >= 0:
+ tweak = self.scheduler.timesteps[tweak_index]
+ fixed_noise = randn_tensor(
+ ref_latents.shape,
+ generator=generator,
+ device=ref_latents.device,
+ dtype=ref_latents.dtype,
+ )
+ noisy_latents = self.scheduler.add_noise(ref_latents, fixed_noise, tweak.long())
+ latents = noisy_latents.to(dtype=latents.dtype, device=latents.device)
+ else:
+ tweak = torch.tensor(-1)
+ fixed_noise = randn_tensor(
+ ref_latents.shape,
+ generator=generator,
+ device=ref_latents.device,
+ dtype=ref_latents.dtype,
+ )
+ tweak_index = 0
+
+ # 10. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=len(timesteps) - tweak_index) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps[tweak_index:]):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
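+ # CogVideoX I2V conditions on the first frame by concatenating its latents along the channel
+ # dim (hence latent_channels = transformer.config.in_channels // 2 above).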
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ ofs=ofs_emb,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i + tweak_index - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+
+ # In between t_weak and t_strong, keep the masked region pinned to the noised reference latents
+ in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
+
+ if in_between_tweak_tstrong:
+ if i+tweak_index+1 < len(timesteps):
+ prev_t = timesteps[i+tweak_index+1]
+ noisy_latents = self.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(dtype=latents.dtype, device=latents.device)
+ latents = latents * background_mask + noisy_latents * motion_mask
+ else:
+ latents = latents * background_mask + ref_latents * motion_mask
+
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - tweak_index - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ # Discard any padding frames that were added for CogVideoX 1.5
+ latents = latents[:, additional_frames:]
+ frames = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=frames, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(
+ frames=video,
+ )
\ No newline at end of file
diff --git a/pipelines/svd_pipeline.py b/pipelines/svd_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71371b30f42ebe30decd43abb5e7791a8c95327
--- /dev/null
+++ b/pipelines/svd_pipeline.py
@@ -0,0 +1,624 @@
+# Copyright 2025 Noam Rotstein
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from Hugging Face Diffusers (Apache-2.0):
+# https://github.com/huggingface/diffusers/blob/8abc7aeb715c0149ee0a9982b2d608ce97f55215/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py#L147
+
+try:
+ import inspect
+ from dataclasses import dataclass
+ from typing import Callable, Dict, List, Optional, Union
+ import numpy as np
+ import PIL.Image
+ import torch
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+ from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
+ from diffusers.schedulers import EulerDiscreteScheduler
+ from diffusers.utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ import torch.nn.functional as F
+ from diffusers.pipelines.stable_video_diffusion import StableVideoDiffusionPipeline
+ from pipelines.utils import load_video_to_tensor
+
+except ImportError as e:
+ raise ImportError(f"Required module not found: {e}. Please install it before running this script. "
+ f"For installation instructions, see: https://github.com/Stability-AI/generative-models")
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+"""
+
+
+def _append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+@dataclass
+class StableVideoDiffusionPipelineOutput(BaseOutput):
+ r"""
+ Output class for Stable Video Diffusion pipeline.
+
+ Args:
+ frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
+ num_frames, height, width, num_channels)`.
+ """
+
+ frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
+
+
+class StableVideoDiffusionTTMPipeline(StableVideoDiffusionPipeline):
+ r"""
+ Pipeline to generate video from an input image using Stable Video Diffusion combined with Time to Move (TTM).
+ This model inherits from [`StableVideoDiffusionPipeline`].
+ """
+
+ model_cpu_offload_seq = "image_encoder->unet->vae"
+ _callback_tensor_inputs = ["latents"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKLTemporalDecoder,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNetSpatioTemporalConditionModel,
+ scheduler: EulerDiscreteScheduler,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
+
+
+ def encode_frames(self, frames: torch.Tensor, num_frames: int, encode_chunk_size: int = 14):
+ """
+ Args:
+ frames: [B, C, T, H, W] tensor, preprocessed to VAE's expected range (e.g., [-1, 1]).
+ num_frames: T (used for reshaping back).
+ encode_chunk_size: process at most this many frames at a time to avoid OOM.
+
+ Returns:
+ latents: [B, T, C_latent, h, w], multiplied by self.vae.config.scaling_factor.
+
+ Notes:
+ - Stochastic: samples from posterior (latent_dist.sample()).
+ - If the VAE's compiled module hides the signature, we inspect the original .forward
+ and pass num_frames only if it's accepted (same pattern as decode).
+ """
+ if frames.dim() != 5:
+ raise ValueError(f"Expected frames with shape [B, C, T, H, W], got {list(frames.shape)}")
+ B, C, T, H, W = frames.shape
+
+ # [B, C, T, H, W] -> [B, T, C, H, W] -> [B*T, C, H, W]
+ frames_bt = frames.permute(0, 2, 1, 3, 4).reshape(-1, C, H, W)
+
+ # Use the *encode* signature (decoder may accept num_frames, encoder usually doesn't)
+ encode_fn = self.vae._orig_mod.encode if hasattr(self.vae, "_orig_mod") else self.vae.encode
+ try:
+ accepts_num_frames = ("num_frames" in inspect.signature(encode_fn).parameters)
+ except (TypeError, ValueError):
+ # Signature might be obscured by wrappers/compilation; be conservative
+ accepts_num_frames = False
+
+ latents_chunks = []
+ for i in range(0, frames_bt.shape[0], encode_chunk_size):
+ chunk = frames_bt[i : i + encode_chunk_size]
+
+ # match VAE device/dtype to avoid implicit casts
+ chunk = chunk.to(device=self.vae.device, dtype=self.vae.dtype)
+
+ encode_kwargs = {}
+ if accepts_num_frames:
+ # This will normally be False for AutoencoderKLTemporalDecoder.encode()
+ encode_kwargs["num_frames"] = chunk.shape[0]
+
+ # Be robust to unexpected wrappers hiding the signature
+ try:
+ enc_out = self.vae.encode(chunk, **encode_kwargs)
+ except TypeError as e:
+ if "unexpected keyword argument 'num_frames'" in str(e):
+ enc_out = self.vae.encode(chunk)
+ else:
+ raise
+
+ posterior = enc_out.latent_dist # DiagonalGaussianDistribution
+ latents_chunks.append(posterior.sample())
+
+ latents = torch.cat(latents_chunks, dim=0) # [B*T, C_lat, h, w]
+ latents = latents * self.vae.config.scaling_factor
+
+ # [B*T, C_lat, h, w] -> [B, T, C_lat, h, w]
+ latents = latents.reshape(B, num_frames, *latents.shape[1:])
+
+ return latents
+
+ def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor, first_different=True) -> torch.Tensor:
+ """
+ Args:
+ mask: [T, 1, H, W] tensor (0/1 or any float in [0,1]).
+ Returns:
+ latent_mask: [1, T_latent, 1, H, W], where
+ T_latent = ceil(T / self.vae_scale_factor_temporal)
+ For CogVideoX-style VAE (k=4), groups are [0], [1-4], [5-8], ..., achieved by
+ pre-padding zeros at the start before max-pooling with stride=k.
+ """
+ T, _, H, W = mask.shape
+
+ k = self.vae_scale_factor_temporal
+ # Pre-pad zeros along time so that the first pooled window corresponds to frame 0 alone
+ if first_different:
+ num_pad = (k - (T % k)) % k
+ pad = torch.zeros((num_pad, 1, H, W), device=mask.device, dtype=mask.dtype)
+ mask = torch.cat([pad, mask], dim=0)
+
+
+ # [T,1,H,W] -> [1,1,T,H,W]
+ x = mask.permute(1, 0, 2, 3).unsqueeze(0)
+ if k > 1:
+ # Max-pool over time with kernel=stride=k (no spatial pooling)
+ pooled = F.max_pool3d(x, kernel_size=(k, 1, 1), stride=(k, 1, 1))
+ else:
+ pooled = x
+
+ # Up-sample spatially to match latent spatial resolution
+ s = self.vae_scale_factor_spatial
+ H_latent = pooled.shape[-2] // s
+ W_latent = pooled.shape[-1] // s
+ pooled = F.interpolate(pooled, size=(pooled.shape[2], H_latent, W_latent), mode="nearest")
+
+ # Back to [1, T_latent, 1, H, W]
+ latent_mask = pooled.permute(0, 2, 1, 3, 4)
+
+ return latent_mask
+
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
+ height: int = 576,
+ width: int = 1024,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 25,
+ sigmas: Optional[List[float]] = None,
+ min_guidance_scale: float = 1.0,
+ max_guidance_scale: float = 3.0,
+ fps: int = 7,
+ motion_bucket_id: int = 127,
+ noise_aug_strength: float = 0.02,
+ decode_chunk_size: Optional[int] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ return_dict: bool = True,
+ motion_signal_video_path: Optional[str] = None,
+ motion_signal_mask_path: Optional[str] = None,
+ tweak_index: int = 0,
+ tstrong_index: int = 0
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
+ 1]`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_frames (`int`, *optional*):
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
+ fps (`int`, *optional*, defaults to 7):
+ Frames per second. The rate at which the generated images shall be exported to a video after
+ generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
+ motion_bucket_id (`int`, *optional*, defaults to 127):
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
+ will be in the video.
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
+ The amount of noise added to the init image, the higher it is the less the video will look like the
+ init image. Increase it for more motion.
+ decode_chunk_size (`int`, *optional*):
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
+ For lower memory usage, reduce `decode_chunk_size`.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ motion_signal_video_path (`str`):
+ Path to the video file containing the motion signal to guide the motion of the generated video.
+ It should be a crude version of the reference video, with pixels with motion dragged to their target.
+ motion_signal_mask_path (`str`):
+ Path to the mask video file containing the motion mask of TTM.
+ The mask should be a binary video, with the conditioning motion pixels set to 1 and the rest set to 0.
+ tweak_index (`int`):
+ The index of the timestep t_weak from which the denoising process starts.
+ tstrong_index (`int`):
+ The index of the timestep t_strong from which denoising runs unconstrained inside the motion-conditioned region; before it, that region is kept pinned to the noised reference latents.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments:
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
+ `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is
+ returned.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width)
+
+ if motion_signal_mask_path is None:
+ raise ValueError("`motion_signal_mask_path` is required for TTM.")
+ if motion_signal_video_path is None:
+ raise ValueError("`motion_signal_video_path` is required for TTM.")
+
+ # 2. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ self._guidance_scale = max_guidance_scale
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
+ fps = fps - 1
+
+ # 4. Encode input image using VAE
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device)
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
+ image = image + noise_aug_strength * noise
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float32)
+
+ image_latents = self._encode_vae_image(
+ image,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+ image_latents = image_latents.to(image_embeddings.dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # Repeat the image latents for each frame so we can concatenate them with the noise
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+
+ # 5. Get Added Time IDs
+ added_time_ids = self._get_add_time_ids(
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ image_embeddings.dtype,
+ batch_size,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+ added_time_ids = added_time_ids.to(device)
+
+ # 6. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
+
+ # ---- Sanity checks for TTM indices (0 ≤ tweak_index < tstrong_index < num_steps) ----
+ if not (0 <= tstrong_index < num_inference_steps):
+ raise ValueError(f"tstrong_index must be in [0, {num_inference_steps-1}], got {tstrong_index}.")
+ if not (0 <= tweak_index < num_inference_steps):
+ raise ValueError(f"tweak_index must be in [0, {num_inference_steps-1}], got {tweak_index}.")
+ if not (tstrong_index > tweak_index):
+ raise ValueError(f"Require tweak_index < tstrong_index, got {tweak_index} >= {tstrong_index}.")
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare guidance scale
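+ # SVD uses per-frame guidance: the scale ramps linearly from min_guidance_scale (first frame)
+ # to max_guidance_scale (last frame) and is broadcast to the latent dimensions below.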
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+ guidance_scale = guidance_scale.to(device, latents.dtype)
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+ self._guidance_scale = guidance_scale
+
+ # 9. Initialize for TTM
+ ref_vid = load_video_to_tensor(motion_signal_video_path).to(device=device) # shape [1, C, T, H, W]
+ refB, refC, refT, refH, refW = ref_vid.shape
+
+ ref_vid = F.interpolate(
+ ref_vid.permute(0, 2, 1, 3, 4).reshape(refB*refT, refC, refH, refW),
+ size=(height, width), mode="bicubic", align_corners=True,
+ ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
+
+ ref_vid = self.video_processor.normalize(ref_vid.to(dtype=self.vae.dtype)) # Normalize and convert dtype for VAE encoding
+
+ if num_frames < refT:
+ logger.warning(f"num_frames ({num_frames}) < input frames ({refT}); trimming reference video.")
+ ref_vid = ref_vid[:, :, :num_frames]
+ elif num_frames > refT:
+ raise ValueError(f"num_frames ({num_frames}) is greater than input frames ({refT}). This is not supported.")
+
+ ref_latents = self.encode_frames(ref_vid, num_frames, decode_chunk_size).detach()
+ ref_latents = ref_latents.to(dtype=latents.dtype, device=device)
+
+ if not hasattr(self, "vae_scale_factor_temporal"): # infer the VAE temporal scale factor if it is not already set
+ if hasattr(self.vae, "scale_factor_temporal"):
+ self.vae_scale_factor_temporal = self.vae.scale_factor_temporal
+ else:
+ if ref_latents.shape[1] == num_frames:
+ self.vae_scale_factor_temporal = 1
+ else:
+ raise ValueError("Please configure the temporal scale factor of the VAE.")
+
+ self.vae_scale_factor_spatial = self.vae_scale_factor
+
+ ref_mask = load_video_to_tensor(motion_signal_mask_path).to(device=device) # shape [1, C, T, H, W]
+
+ mB, mC, mT, mH, mW = ref_mask.shape # do resizing with nearest neighbor to avoid interpolation artifacts
+ ref_mask = F.interpolate(
+ ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
+ size=(height, width), mode="nearest",
+ ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
+ ref_mask = ref_mask[0].permute(1, 0, 2, 3).contiguous() # (1, C, T, H, W) -> (T, C, H, W)
+ if ref_mask.shape[0] > num_frames:
+ logger.warning(f"num_frames ({num_frames}) is less than input mask frames ({mT}); trimming to {num_frames}.")
+ ref_mask = ref_mask[:num_frames]
+ elif ref_mask.shape[0] < num_frames:
+ raise ValueError(f"num_frames ({num_frames}) is greater than input mask frames ({mT}). This is not supported.")
+ ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0] # [T, 1, H, W]
+ motion_mask = self.convert_rgb_mask_to_latent_mask(ref_mask, False) # [1, T, 1, H, W]
+ motion_mask = motion_mask.to(dtype=latents.dtype)
+ background_mask = 1.0 - motion_mask
+
+ if tweak_index >= 0:
+ tweak = self.scheduler.timesteps[tweak_index]
+ tweak = torch.tensor([tweak], device=device)
+ fixed_noise = randn_tensor(ref_latents.shape,
+ generator=generator,
+ device=ref_latents.device,
+ dtype=ref_latents.dtype)
+ noisy_latents = self.scheduler.add_noise(ref_latents, fixed_noise, tweak)
+ latents = noisy_latents.to(dtype=latents.dtype, device=latents.device)
+ else:
+ tweak = torch.tensor(-1)
+ fixed_noise = randn_tensor(ref_latents.shape,
+ generator=generator,
+ device=ref_latents.device,
+ dtype=ref_latents.dtype)
+ tweak_index = 0
+
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=len(timesteps) - tweak_index) as progress_bar:
+ for i, t in enumerate(timesteps[tweak_index:]):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # Concatenate image_latents over channels dimension
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=image_embeddings,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ # In between t_weak and t_strong, keep the masked region pinned to the noised reference latents
+ in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
+
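+ # While still before t_strong, re-noise the reference latents to the *next* timestep so the
+ # pinned region stays at the same noise level as the rest of the sample.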
+ if in_between_tweak_tstrong:
+ if i+tweak_index+1 < len(timesteps):
+ prev_t = torch.tensor([timesteps[i+tweak_index+1]], device=device)
+ noisy_latents = self.scheduler.add_noise(ref_latents, fixed_noise, prev_t).to(dtype=latents.dtype, device=latents.device)
+ latents = latents * background_mask + noisy_latents * motion_mask
+ elif i+tweak_index+1 == len(timesteps):
+ latents = latents * background_mask + ref_latents * motion_mask
+ else:
+ raise ValueError(f"Unexpected timestep index {i+tweak_index+1} >= {len(timesteps)}")
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - tweak_index - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
+ frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
+ else:
+ frames = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return frames
+
+ return StableVideoDiffusionPipelineOutput(
+ frames=frames)
diff --git a/pipelines/utils.py b/pipelines/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..24073c100fd4c474082452f031fede58834b5fb4
--- /dev/null
+++ b/pipelines/utils.py
@@ -0,0 +1,45 @@
+import logging
+import os
+from pathlib import Path
+from typing import Tuple
+import numpy as np
+import cv2
+import torch
+
+def validate_inputs(image_path: str, mask_path: str, motion_path: str) -> None:
+ for p in [image_path, mask_path, motion_path]:
+ if not Path(p).exists():
+ raise FileNotFoundError(f"Required file not found: {p}")
+
+def compute_hw_from_area(
+ image_height: int,
+ image_width: int,
+ max_area: int,
+ mod_value: int,
+) -> Tuple[int, int]:
+ """Compute (height, width) with same math and rounding as original."""
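+ # e.g. a 720x1280 image with max_area=480*832 and mod_value=16 yields (height, width) = (464, 832)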
+ aspect_ratio = image_height / image_width
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ return int(height), int(width)
+
+
+def load_video_to_tensor(video_path):
+ """Return a video tensor read from a video file, shaped [1, C, T, H, W] with values in [0, 1]."""
+ # load video
+ cap = cv2.VideoCapture(video_path)
+ frames = []
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(frame)
+ cap.release()
+ # Stack frames into a numpy array of shape [T, H, W, C] with uint8 values in [0, 255]
+ frames = np.array(frames)
+
+ video_tensor = torch.tensor(frames)
+ video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0
+ video_tensor = video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4)
+ return video_tensor
diff --git a/pipelines/wan_pipeline.py b/pipelines/wan_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf213b314244006b6464faf200ec28ec7ec45bf1
--- /dev/null
+++ b/pipelines/wan_pipeline.py
@@ -0,0 +1,559 @@
+# Copyright 2025 Noam Rotstein
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from Hugging Face Diffusers (Apache-2.0):
+# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+
+
+try:
+ import html
+ from typing import Any, Callable, Dict, List, Optional, Union
+ import torch
+ from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
+ from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents, WanImageToVideoPipeline
+
+ import torch.nn.functional as F
+ from pipelines.utils import load_video_to_tensor
+
+ if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+ else:
+ XLA_AVAILABLE = False
+except ImportError as e:
+ raise ImportError(f"Required module not found: {e}. Please install it before running this script. "
+ f"For installation instructions, see: https://github.com/Wan-Video/Wan2.2")
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ # Optionally import ftfy (text fixing) when it is available
+_ftfy = None
+if is_ftfy_available():
+ import ftfy as _ftfy
+
+
+EXAMPLE_DOC_STRING = """
+"""
+
+
+class WanImageToVideoTTMPipeline(WanImageToVideoPipeline):
+ r"""
+ Pipeline for image-to-video generation using Wan with Time-To-Move (TTM) conditioning.
+ This model inherits from [`WanImageToVideoPipeline`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ image_processor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModel = None,
+ transformer: WanTransformer3DModel = None,
+ transformer_2: WanTransformer3DModel = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False,
+ ):
+ super().__init__(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ transformer_2=transformer_2,
+ boundary_ratio=boundary_ratio,
+ expand_timesteps=expand_timesteps,
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ transformer_2=transformer_2,
+ )
+ self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+
+ def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W'].
+ T_latent groups frames by the temporal VAE downsample factor k = vae_scale_factor_temporal:
+ [0], [1..k], [k+1..2k], ...
+ """
+
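+ # Approximate the latent-frame grouping of Wan's causal VAE by sampling frame 0 and then one
+ # representative frame from each subsequent group of k frames (indices 1, 1+k, 1+2k, ...).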
+ k = self.vae_scale_factor_temporal
+ mask0 = mask[0:1] # [1,1,H,W]
+ mask1 = mask[1::k] # [T'-1,1,H,W]
+ sampled = torch.cat([mask0, mask1], dim=0) # [T',1,H,W]
+ pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0)
+
+ # Up-sample spatially to match latent spatial resolution
+ spatial_downsample = self.vae_scale_factor_spatial
+ H_latent = pooled.shape[-2] // spatial_downsample
+ W_latent = pooled.shape[-1] // spatial_downsample
+ pooled = F.interpolate(pooled, size=(pooled.shape[2], H_latent, W_latent), mode="nearest")
+
+ # Back to [1, T_latent, 1, H, W]
+ latent_mask = pooled.permute(0, 2, 1, 3, 4)
+
+ return latent_mask
+
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ motion_signal_video_path: Optional[str] = None,
+ motion_signal_mask_path: Optional[str] = None,
+ tweak_index: int = 0,
+ tstrong_index: int = 0
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `480`):
+ The height of the generated video.
+ width (`int`, defaults to `832`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+ motion_signal_video_path (`str`):
+ Path to the video file containing the motion signal to guide the motion of the generated video.
+ It should be a crude version of the reference video, with pixels with motion dragged to their target.
+ motion_signal_mask_path (`str`):
+ Path to the mask video file containing the motion mask of TTM.
+ The mask should be a binary video, with the conditioning motion pixels set to 1 and the rest set to 0.
+ tweak_index (`int`):
+ The index of the timestep t_weak from which the denoising process starts.
+ tstrong_index (`int`):
+ The index of the timestep t_strong from which denoising runs unconstrained inside the motion-conditioned region; before it, that region is kept pinned to the noised reference latents.
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ if motion_signal_video_path is None:
+ raise ValueError("`motion_signal_video_path` must be provided for TTM.")
+ if motion_signal_mask_path is None:
+ raise ValueError("`motion_signal_mask_path` must be provided for TTM.")
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
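+ # e.g. with a temporal factor of 4, num_frames=83 is adjusted to 81 so that num_frames - 1 is divisible by 4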
+
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # only wan 2.1 i2v transformer accepts image_embeds
+ if self.transformer is not None and self.transformer.config.image_dim is not None:
+ if image_embeds is None:
+ if last_image is None:
+ image_embeds = self.encode_image(image, device)
+ else:
+ image_embeds = self.encode_image([image, last_image], device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ tweak_index = int(tweak_index)
+ tstrong_index = int(tstrong_index)
+
+ if tweak_index < -1:
+ raise ValueError(f"`tweak_index` ({tweak_index}) must be >= -1.")
+ if tweak_index >= len(timesteps):
+ raise ValueError(f"`tweak_index` ({tweak_index}) must be < {len(timesteps)}.")
+
+ if tstrong_index < 0:
+ raise ValueError(f"`tstrong_index` ({tstrong_index}) must be >= 0.")
+ if tstrong_index >= len(timesteps):
+ raise ValueError(f"`tstrong_index` ({tstrong_index}) must be < {len(timesteps)}.")
+ if tstrong_index < max(0, tweak_index):
+ raise ValueError(f"`tstrong_index` ({tstrong_index}) must be >= `tweak_index` ({tweak_index}).")
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ latents_outputs = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ last_image,
+ )
+ if self.config.expand_timesteps:
+ latents, condition, first_frame_mask = latents_outputs
+ else:
+ latents, condition = latents_outputs
+
+ # 6. Initialize for TTM
+ ref_vid = load_video_to_tensor(motion_signal_video_path).to(device=device) # shape [1, C, T, H, W]
+ refB, refC, refT, refH, refW = ref_vid.shape
+
+ ref_vid = F.interpolate(
+ ref_vid.permute(0, 2, 1, 3, 4).reshape(refB*refT, refC, refH, refW),
+ size=(height, width), mode="bicubic", align_corners=True,
+ ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4)
+
+ ref_vid = self.video_processor.normalize(ref_vid.to(dtype=self.vae.dtype)) # [1, C, T, H, W]
+ ref_latents = retrieve_latents(self.vae.encode(ref_vid), sample_mode="argmax") # [1, z, T', H', W']
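+ # Standardize the reference latents with the VAE's per-channel latents_mean / latents_std so they
+ # match the space of `latents`; the inverse transform is applied again before decoding.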
+ latents_mean = torch.tensor(self.vae.config.latents_mean)\
+ .view(1, self.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std)\
+ .view(1, self.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype)
+ ref_latents = (ref_latents - latents_mean) * latents_std
+
+
+ ref_mask = load_video_to_tensor(motion_signal_mask_path).to(device=device) # shape [1, C, T, H, W]
+ mB, mC, mT, mH, mW = ref_mask.shape
+ ref_mask = F.interpolate(
+ ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW),
+ size=(height, width), mode="nearest",
+ ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4)
+ mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous() # [1, C, T, H, W] -> [T, C, H, W]
+
+ if mask_tc_hw.shape[0] > num_frames: # Align time dimension to num_frames
+ logger.warning("Mask has %d frames but num_frames=%d; trimming.", mask_tc_hw.shape[0], num_frames)
+ mask_tc_hw = mask_tc_hw[:num_frames]
+ elif mask_tc_hw.shape[0] < num_frames:
+ raise ValueError(f"num_frames ({num_frames}) is greater than mask frames ({mask_tc_hw.shape[0]}). "
+ "Please pad/extend your mask or lower num_frames.")
+
+ if mask_tc_hw.shape[1] > 1: # Reduce channels if needed -> [T,1,H,W], binarize once
+ mask_t1_hw = (mask_tc_hw > 0.5).any(dim=1, keepdim=True).float()
+ else:
+ mask_t1_hw = (mask_tc_hw > 0.5).float()
+
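+ # convert_rgb_mask_to_latent_mask returns [1, T', 1, H', W']; permute to [1, 1, T', H', W']
+ # so the mask broadcasts against the Wan latents laid out as [B, C, T', H', W'].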
+ motion_mask = self.convert_rgb_mask_to_latent_mask(mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous()
+ background_mask = 1.0 - motion_mask
+
+ if tweak_index >= 0:
+ tweak = timesteps[tweak_index]
+ fixed_noise = randn_tensor(
+ ref_latents.shape,
+ generator=generator,
+ device=ref_latents.device,
+ dtype=ref_latents.dtype,
+ )
+ tweak = torch.as_tensor(tweak, device=ref_latents.device, dtype=torch.long).view(1)
+ noisy_latents = self.scheduler.add_noise(ref_latents, fixed_noise, tweak.long())
+ latents = noisy_latents.to(dtype=latents.dtype, device=latents.device)
+ else:
+ tweak = torch.tensor(-1)
+ fixed_noise = randn_tensor(
+ ref_latents.shape,
+ generator=generator,
+ device=ref_latents.device,
+ dtype=ref_latents.dtype,
+ )
+ tweak_index = 0
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ with self.progress_bar(total=len(timesteps) - tweak_index) as progress_bar:
+ for i, t in enumerate(timesteps[tweak_index:]):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
+ if self.config.expand_timesteps:
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # In between t_weak and t_strong, keep the masked region pinned to the noised reference latents
+ in_between_tweak_tstrong = (i + tweak_index) < tstrong_index
+
+ if in_between_tweak_tstrong:
+ if i+tweak_index+1 < len(timesteps):
+ prev_t = timesteps[i+tweak_index+1]
+ prev_t = torch.as_tensor(prev_t, device=ref_latents.device, dtype=torch.long).view(1)
+ noisy_latents = self.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to(dtype=latents.dtype, device=latents.device)
+ latents = latents * background_mask + noisy_latents * motion_mask
+ else:
+ latents = latents * background_mask + ref_latents.to(dtype=latents.dtype, device=latents.device) * motion_mask
+
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - tweak_index - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if self.config.expand_timesteps:
+ latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
\ No newline at end of file
diff --git a/run_cog.py b/run_cog.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b24a78b81111f24c546726d06d0d6b69fa28df6
--- /dev/null
+++ b/run_cog.py
@@ -0,0 +1,92 @@
+try:
+ import argparse
+ import os
+ import torch
+ from pipelines.cog_pipeline import CogVideoXImageToVideoTTMPipeline
+ from diffusers.utils import export_to_video, load_image
+ from pathlib import Path
+except ImportError as e:
+ raise ImportError(f"Required module not found: {e}. Please install it before running this script. "
+ f"For installation instructions, see: https://github.com/zai-org/CogVideo")
+
+
+MODEL_ID = "THUDM/CogVideoX-5b-I2V"
+DTYPE = torch.bfloat16
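+# Example:
+#   python run_cog.py --input-path ./examples/cutdrag_cog_Monkey --output-path ./outputs/output_cog_Monkey.mp4
+# The input directory is expected to contain first_frame.png, mask.mp4, motion_signal.mp4 and prompt.txt.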
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Run CogVideoX Image-to-Video TTM Pipeline")
+ parser.add_argument("--input-path", type=str, default="./examples/cutdrag_cog_Monkey", help="Path to the input example directory")
+ parser.add_argument("--output-path", type=str, default="./outputs/output_cog_Monkey.mp4", help="Path to save the output video")
+ parser.add_argument("--tweak-index", type=int, default=4, help="t_weak timestep index: when to start denoising")
+ parser.add_argument("--tstrong-index", type=int, default=8, help="t_strong timestep index: when to start denoising inside the mask")
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("--num-frames", type=int, default=49, help="Number of frames to generate")
+ parser.add_argument("--guidance-scale", type=float, default=6.0, help="Guidance scale for generation")
+ return parser.parse_args()
+
+
+args = parse_args()
+
+image_path = os.path.join(args.input_path, "first_frame.png")
+motion_signal_mask_path = os.path.join(args.input_path, "mask.mp4")
+motion_signal_video_path = os.path.join(args.input_path, "motion_signal.mp4")
+prompt_path = os.path.join(args.input_path, "prompt.txt")
+
+num_inference_steps = args.num_inference_steps
+seed = args.seed
+tweak_index = args.tweak_index
+tstrong_index = args.tstrong_index
+num_frames = args.num_frames
+guidance_scale = args.guidance_scale
+output_path = args.output_path
+
+
+# make sure output directory exists
+Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
+
+# -----------------------
+# Setup Pipeline
+# -----------------------
+def setup_cog_pipeline(model_id: str, dtype: torch.dtype):
+ pipe = CogVideoXImageToVideoTTMPipeline.from_pretrained(
+ model_id,
+ torch_dtype=dtype,
+ low_cpu_mem_usage=True,
+ device_map="balanced", # keep this
+ )
+ pipe.vae.enable_tiling() # pipe.enable_vae_slicing()
+ pipe.vae.enable_slicing() # pipe.enable_vae_tiling()
+ return pipe
+
+
+def main():
+ pipe = setup_cog_pipeline(MODEL_ID, DTYPE)
+ image = load_image(image_path)
+ # Load the prompt
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ prompt = f.read().strip()
+
+ result = pipe(
+ [image],
+ [prompt],
+ generator=torch.Generator().manual_seed(seed),
+ num_inference_steps=num_inference_steps,
+ motion_signal_video_path=motion_signal_video_path,
+ motion_signal_mask_path=motion_signal_mask_path,
+ tweak_index=tweak_index,
+ tstrong_index=tstrong_index,
+ num_frames=num_frames,
+ guidance_scale=guidance_scale
+ )
+
+ frames = result.frames[0]
+ Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
+ export_to_video(frames, output_path, fps=8)
+ print(f"Video saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/run_svd.py b/run_svd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c367fdbcb52a8233122d5539d7a1052557d94e5
--- /dev/null
+++ b/run_svd.py
@@ -0,0 +1,84 @@
+try:
+ import argparse
+ import os
+ import torch
+ from pipelines.svd_pipeline import StableVideoDiffusionTTMPipeline
+ from diffusers.utils import export_to_video, load_image
+ from PIL import Image
+ import json
+ from pathlib import Path
+except ImportError as e:
+ raise ImportError(f"Required module not found: {e}. Please install it before running this script. "
+ f"For installation instructions, see: https://github.com/Stability-AI/generative-models")
+
+
+MODEL_ID = "stabilityai/stable-video-diffusion-img2vid-xt"
+DTYPE = torch.float16
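+# Example:
+#   python run_svd.py --input-path ./examples/cutdrag_svd_Fish --output-path ./outputs/output_svd_Fish.mp4
+# The input directory is expected to contain first_frame.png, mask.mp4 and motion_signal.mp4.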
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Run Stable Video Diffusion Image-to-Video TTM Pipeline")
+ parser.add_argument("--input-path", type=str, default="./examples/cutdrag_cog_Monkey", help="Path to input image")
+ parser.add_argument("--output-path", type=str, default="./outputs/output_cog_Monkey.mp4", help="Path to save output video")
+ parser.add_argument("--tweak-index", type=int, default=16, help="t weak timestep index- when to start denoising")
+ parser.add_argument("--tstrong-index", type=int, default=21, help="t strong timestep index- when to start denoising within the mask")
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("--num-frames", type=int, default=21, help="Number of frames to generate")
+ parser.add_argument("--motion_bucket_id", type=int, default=17, help="Amount of motion to condition on")
+ return parser.parse_args()
+
+
+args = parse_args()
+
+image_path = os.path.join(args.input_path, "first_frame.png")
+motion_signal_mask_path = os.path.join(args.input_path, "mask.mp4")
+motion_signal_video_path = os.path.join(args.input_path, "motion_signal.mp4")
+
+num_inference_steps = args.num_inference_steps
+seed = args.seed
+tweak_index = args.tweak_index
+tstrong_index = args.tstrong_index
+num_frames = args.num_frames
+motion_bucket_id = args.motion_bucket_id
+output_path = args.output_path
+
+
+# make sure output directory exists
+Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
+
+# -----------------------
+# Setup Pipeline
+# -----------------------
+def setup_svd_pipeline(model_id: str, dtype: torch.dtype):
+ pipe = StableVideoDiffusionTTMPipeline.from_pretrained(
+ model_id,
+ torch_dtype=dtype,
+ device_map="balanced", # keep this
+ variant="fp16"
+ )
+ return pipe
+
+
+def main():
+ pipe = setup_svd_pipeline(MODEL_ID, DTYPE)
+ image = load_image(image_path)
+ result = pipe(
+ [image],
+ generator=torch.Generator().manual_seed(seed),
+ motion_bucket_id=motion_bucket_id,
+ num_inference_steps=num_inference_steps,
+ motion_signal_video_path=motion_signal_video_path,
+ motion_signal_mask_path=motion_signal_mask_path,
+ tweak_index=tweak_index,
+ tstrong_index=tstrong_index,
+ num_frames=num_frames
+ )
+
+ frames = result.frames[0]
+ Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
+ export_to_video(frames, output_path, fps=7)
+ print(f"Video saved to {output_path}")
+
+if __name__ == "__main__":
+ main()
+
diff --git a/run_wan.py b/run_wan.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f1da5ba689fd28352488273b3027b53d3c38fb
--- /dev/null
+++ b/run_wan.py
@@ -0,0 +1,119 @@
+try:
+ import os
+ from pathlib import Path
+ import torch
+ from diffusers.utils import export_to_video, load_image
+ from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
+ from pipelines.utils import (
+ validate_inputs,
+ compute_hw_from_area,
+ )
+ import argparse
+except ImportError as e:
+ raise ImportError(f"Required module not found: {e}. Please install it before running this script. "
+ f"For installation instructions, see: https://github.com/Wan-Video/Wan2.2")
+
+MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+DTYPE = torch.bfloat16
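+# Example:
+#   python run_wan.py --input-path ./examples/cutdrag_wan_Monkey --output-path ./outputs/output_wan_Monkey.mp4
+# The input directory is expected to contain first_frame.png, mask.mp4, motion_signal.mp4 and prompt.txt.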
+
+# -----------------------
+# Argument Parser
+# -----------------------
+def parse_args():
+ parser = argparse.ArgumentParser(description="Run Wan Image to Video Pipeline")
+ parser.add_argument("--input-path", type=str, default="./examples/wan_monkey", help="Path to input image")
+ parser.add_argument("--output-path", type=str, default="./outputs/output_wan_monkey.mp4", help="Path to save output video")
+ parser.add_argument("--negative-prompt", type=str, default=(
+ "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
+ "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
+ "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+ ), help="Default negative prompt in Wan2.2")
+ parser.add_argument("--tweak-index", type=int, default=3, help="t weak timestep index- when to start denoising")
+ parser.add_argument("--tstrong-index", type=int, default=6, help="t strong timestep index- when to start denoising within the mask")
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use")
+ parser.add_argument("--max-area", type=int, default=480 * 832, help="Maximum area for resizing")
+ parser.add_argument("--num-frames", type=int, default=81, help="Number of frames to generate")
+ parser.add_argument("--guidance-scale", type=float, default=3.5, help="Guidance scale for generation")
+ return parser.parse_args()
+
+
+args = parse_args()
+image_path = os.path.join(args.input_path, "first_frame.png")
+motion_signal_mask_path = os.path.join(args.input_path, "mask.mp4")
+motion_signal_video_path = os.path.join(args.input_path, "motion_signal.mp4")
+prompt_path = os.path.join(args.input_path, "prompt.txt")
+
+output_path = args.output_path
+negative_prompt = args.negative_prompt
+tweak_index = args.tweak_index
+tstrong_index = args.tstrong_index
+num_inference_steps = args.num_inference_steps
+seed = args.seed
+device = args.device
+max_area = args.max_area
+num_frames = args.num_frames
+guidance_scale = args.guidance_scale
+
+# make sure output directory exists
+Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
+
+# -----------------------
+# Setup Pipeline
+# -----------------------
+def setup_wan_pipeline(model_id: str, dtype: torch.dtype, device: str):
+ pipe = WanImageToVideoTTMPipeline.from_pretrained(model_id, torch_dtype=dtype)
+ pipe.vae.enable_tiling()
+ pipe.vae.enable_slicing()
+ pipe.to(device)
+ return pipe
+
+
+# -----------------------
+# Main
+# -----------------------
+def main():
+ validate_inputs(image_path, motion_signal_mask_path, motion_signal_video_path)
+ pipe = setup_wan_pipeline(MODEL_ID, DTYPE, device)
+
+ # Load and resize the input image
+ image = load_image(image_path)
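+ # Height/width are rounded so they stay divisible by the VAE spatial scale factor times the
+ # transformer patch size, keeping the latent grid aligned with the patch layout.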
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height, width = compute_hw_from_area(image.height, image.width, max_area, mod_value)
+ image = image.resize((width, height))
+
+ # Load the prompt
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ prompt = f.read().strip()
+
+ # Generator / seed
+ gen_device = device if device.startswith("cuda") else "cpu"
+ generator = torch.Generator(device=gen_device).manual_seed(seed)
+
+ with torch.inference_mode():
+ result = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ motion_signal_video_path=motion_signal_video_path,
+ motion_signal_mask_path=motion_signal_mask_path,
+ tweak_index=tweak_index,
+ tstrong_index=tstrong_index,
+ )
+
+ frames = result.frames[0]
+ Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
+ export_to_video(frames, output_path, fps=16)
+ print(f"Video saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()