# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import shutil

import cv2
import numpy as np
import torch
from diffusers import FluxKontextPipeline
from loguru import logger
from PIL import Image

try:
    import moviepy.editor as mpy
except ImportError:
    import moviepy as mpy

from decord import VideoReader

from pose2d import Pose2d
from pose2d_utils import AAPoseMeta
from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes, get_aug_mask, get_mask_body_img
from human_visualization import draw_aapose_by_meta_new
from retarget_pose import get_retarget_pose
from sam2.build_sam import build_sam2, build_sam2_video_predictor


def get_frames(video_path, resolution_area, fps=30):
    video_reader = VideoReader(video_path)
    frame_num = len(video_reader)
    video_fps = video_reader.get_avg_fps()
    # TODO: Maybe we can switch to PyAV later, which can get an accurate frame count.
    duration = video_reader.get_frame_timestamp(-1)[-1]
    expected_frame_num = int(duration * video_fps + 0.5)
    ratio = abs((frame_num - expected_frame_num) / frame_num)
    if ratio > 0.1:
        print("Warning: the difference between the actual and expected number of frames is too large")
        frame_num = expected_frame_num

    if fps == -1:
        fps = video_fps

    target_num = int(frame_num / video_fps * fps)
    idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
    frames = video_reader.get_batch(idxs).asnumpy()
    frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]
    return frames
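

# Note on get_frames: frames are sampled by index. For example, a 300-frame clip at
# 25 fps requested at fps=30 gives target_num = int(300 / 25 * 30) = 360 sample
# indices (chosen by get_frame_indices); each sampled frame is then resized by
# resize_by_area, presumably to roughly resolution_area[0] * resolution_area[1]
# pixels with dimensions divisible by 16.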


def quantize_mask_blocky(mask, block_w=16, block_h=16, occupancy=0.15):
    """
    Convert a binary mask to a blocky (quantized) mask.

    - block_w, block_h: target block size in pixels.
    - occupancy: fraction in [0, 1] of foreground within a block required to turn it on.
    """
    m = (mask > 0).astype(np.uint8)
    H, W = m.shape[:2]

    # Compute the block-grid size.
    grid_w = max(1, int(np.ceil(W / block_w)))
    grid_h = max(1, int(np.ceil(H / block_h)))

    # Downsample to the grid with area interpolation; casting to float keeps the
    # fractional per-block occupancy instead of a value rounded back to 0/1.
    small = cv2.resize(m.astype(np.float32), (grid_w, grid_h), interpolation=cv2.INTER_AREA)

    # Threshold by occupancy (values are in [0, 1] because the source was 0/1).
    small_q = (small >= occupancy).astype(np.uint8)

    # Upsample back with nearest-neighbor interpolation (keeps sharp blocks).
    blocky = cv2.resize(small_q, (W, H), interpolation=cv2.INTER_NEAREST)
    return blocky
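

# Example: for a 720x1280 mask with 16x16 blocks the grid is ceil(720/16) x ceil(1280/16)
# = 45 x 80 cells. After the INTER_AREA downsample each cell holds the fraction of
# foreground pixels it covers, so occupancy=0.15 switches a cell on once roughly 15%
# of its pixels are foreground; the nearest-neighbor upsample then paints whole blocks.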


class ProcessPipeline():
    def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
        self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)

        if sam_checkpoint_path is not None:
            # sam_checkpoint_path is a (checkpoint, model_cfg) pair.
            model_cfg = sam_checkpoint_path[1]
            self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path[0], device="cuda")

        if flux_kontext_path is not None:
            self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
    def __call__(self, video_path,
                 refer_image_path,
                 output_path,
                 resolution_area=[1280, 720],
                 fps=30,
                 iterations=3,
                 k=7,
                 w_len=1,
                 h_len=1,
                 retarget_flag=False,
                 use_flux=False,
                 replace_flag=False,
                 pts_by_frame=None,
                 lbs_by_frame=None):
        if replace_flag:
            frames = get_frames(video_path, resolution_area, fps)
            height, width = frames[0].shape[:2]

            if not pts_by_frame and not lbs_by_frame:
                # No user clicks supplied: estimate poses first, then derive the
                # subject mask from the face bounding boxes.
                tpl_pose_metas = self.pose2d(frames)

                face_images = []
                for idx, meta in enumerate(tpl_pose_metas):
                    face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
                                                          image_shape=(frames[0].shape[0], frames[0].shape[1]))
                    x1, x2, y1, y2 = face_bbox_for_image
                    face_image = frames[idx][y1:y2, x1:x2]
                    face_image = cv2.resize(face_image, (512, 512))
                    face_images.append(face_image)

                logger.info(f"Processing reference image: {refer_image_path}")
                refer_img = cv2.imread(refer_image_path)
                src_ref_path = os.path.join(output_path, 'src_ref.png')
                shutil.copy(refer_image_path, src_ref_path)
                refer_img = refer_img[..., ::-1]
                refer_img = padding_resize(refer_img, height, width)

                logger.info(f"Processing template video: {video_path}")
                tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]

                cond_images = []
                for idx, meta in enumerate(tpl_retarget_pose_metas):
                    canvas = np.zeros_like(refer_img)
                    conditioning_image = draw_aapose_by_meta_new(canvas, meta)
                    cond_images.append(conditioning_image)

                # Segment the subject with SAM2 seeded from the face boxes, then
                # augment the masks and cut out the background.
                masks = self.get_mask_from_face_bbox(frames, 400, tpl_pose_metas)
                bg_images = []
                aug_masks = []
                for frame, mask in zip(frames, masks):
                    if iterations > 0:
                        _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
                        each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
                    else:
                        each_aug_mask = mask

                    each_bg_image = frame * (1 - each_aug_mask[:, :, None])
                    bg_images.append(each_bg_image)
                    aug_masks.append(each_aug_mask)
            else:
                # User clicks were supplied: segment with SAM2 point prompts first,
                # build blocky masks, and pass the masks to pose estimation via bbx.
                masks = self.get_mask_from_face_bbox_v2(frames, pts_by_frame=pts_by_frame, lbs_by_frame=lbs_by_frame)
                bg_images = []
                aug_masks = []
                for frame, mask in zip(frames, masks):
                    if iterations > 0:
                        _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
                        # each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
                        each_aug_mask = quantize_mask_blocky(each_mask, block_w=16, block_h=16, occupancy=0.15)
                    else:
                        each_aug_mask = mask

                    each_bg_image = frame * (1 - each_aug_mask[:, :, None])
                    bg_images.append(each_bg_image)
                    aug_masks.append(each_aug_mask)

                tpl_pose_metas = self.pose2d(
                    frames,
                    bbx=masks,  # per-frame masks from SAM2
                )
                face_images = []
                for idx, meta in enumerate(tpl_pose_metas):
                    face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
                                                          image_shape=(frames[0].shape[0], frames[0].shape[1]))
                    x1, x2, y1, y2 = face_bbox_for_image
                    face_image = frames[idx][y1:y2, x1:x2]
                    face_image = cv2.resize(face_image, (512, 512))
                    face_images.append(face_image)

                logger.info(f"Processing reference image: {refer_image_path}")
                refer_img = cv2.imread(refer_image_path)
                src_ref_path = os.path.join(output_path, 'src_ref.png')
                shutil.copy(refer_image_path, src_ref_path)
                refer_img = refer_img[..., ::-1]
                refer_img = padding_resize(refer_img, height, width)

                logger.info(f"Processing template video: {video_path}")
                tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]

                cond_images = []
                for idx, meta in enumerate(tpl_retarget_pose_metas):
                    canvas = np.zeros_like(refer_img)
                    conditioning_image = draw_aapose_by_meta_new(canvas, meta)
                    cond_images.append(conditioning_image)

            src_face_path = os.path.join(output_path, 'src_face.mp4')
            mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path, logger=None)
            src_pose_path = os.path.join(output_path, 'src_pose.mp4')
            mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path, logger=None)
            src_bg_path = os.path.join(output_path, 'src_bg.mp4')
            mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path, logger=None)

            aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]
            src_mask_path = os.path.join(output_path, 'src_mask.mp4')
            mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path, logger=None)
            return True
        else:
            logger.info(f"Processing reference image: {refer_image_path}")
            refer_img = cv2.imread(refer_image_path)
            src_ref_path = os.path.join(output_path, 'src_ref.png')
            shutil.copy(refer_image_path, src_ref_path)
            refer_img = refer_img[..., ::-1]
            refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)
            refer_pose_meta = self.pose2d([refer_img])[0]

            logger.info(f"Processing template video: {video_path}")
            video_reader = VideoReader(video_path)
            frame_num = len(video_reader)
            video_fps = video_reader.get_avg_fps()
            # TODO: Maybe we can switch to PyAV later, which can get an accurate frame count.
            duration = video_reader.get_frame_timestamp(-1)[-1]
            expected_frame_num = int(duration * video_fps + 0.5)
            ratio = abs((frame_num - expected_frame_num) / frame_num)
            if ratio > 0.1:
                print("Warning: the difference between the actual and expected number of frames is too large")
                frame_num = expected_frame_num

            if fps == -1:
                fps = video_fps

            target_num = int(frame_num / video_fps * fps)
            idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
            frames = video_reader.get_batch(idxs).asnumpy()

            logger.info("Processing pose meta")
            tpl_pose_meta0 = self.pose2d(frames[:1])[0]
            tpl_pose_metas = self.pose2d(frames)

            face_images = []
            for idx, meta in enumerate(tpl_pose_metas):
                face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
                                                      image_shape=(frames[0].shape[0], frames[0].shape[1]))
                x1, x2, y1, y2 = face_bbox_for_image
                face_image = frames[idx][y1:y2, x1:x2]
                face_image = cv2.resize(face_image, (512, 512))
                face_images.append(face_image)

            if retarget_flag:
                if use_flux:
                    # Edit both the reference image and a template frame toward a
                    # canonical pose with Flux-Kontext, and feed the edited poses to
                    # get_retarget_pose.
                    tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)

                    refer_input = Image.fromarray(refer_img)
                    refer_edit = self.flux_kontext(
                        image=refer_input,
                        height=refer_img.shape[0],
                        width=refer_img.shape[1],
                        prompt=refer_prompt,
                        guidance_scale=2.5,
                        num_inference_steps=28,
                    ).images[0]
                    refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))
                    refer_edit_path = os.path.join(output_path, 'refer_edit.png')
                    refer_edit.save(refer_edit_path)
                    refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]

                    tpl_img = frames[1]
                    tpl_input = Image.fromarray(tpl_img)
                    tpl_edit = self.flux_kontext(
                        image=tpl_input,
                        height=tpl_img.shape[0],
                        width=tpl_img.shape[1],
                        prompt=tpl_prompt,
                        guidance_scale=2.5,
                        num_inference_steps=28,
                    ).images[0]
                    tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))
                    tpl_edit_path = os.path.join(output_path, 'tpl_edit.png')
                    tpl_edit.save(tpl_edit_path)
                    tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]

                    tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)
                else:
                    tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)
            else:
                tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]

            cond_images = []
            for idx, meta in enumerate(tpl_retarget_pose_metas):
                if retarget_flag:
                    canvas = np.zeros_like(refer_img)
                    conditioning_image = draw_aapose_by_meta_new(canvas, meta)
                else:
                    canvas = np.zeros_like(frames[0])
                    conditioning_image = draw_aapose_by_meta_new(canvas, meta)
                    conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])
                cond_images.append(conditioning_image)

            src_face_path = os.path.join(output_path, 'src_face.mp4')
            mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path, logger=None)
            src_pose_path = os.path.join(output_path, 'src_pose.mp4')
            mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path, logger=None)
            return True
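
    # Outputs written by __call__ under output_path: src_ref.png, src_face.mp4 and
    # src_pose.mp4 in both modes; additionally src_bg.mp4 and src_mask.mp4 when
    # replace_flag is set, and refer_edit.png / tpl_edit.png when retarget_flag and
    # use_flux are both enabled.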

    def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
        arm_visible = False
        leg_visible = False
        for tpl_pose_meta in tpl_pose_metas:
            tpl_keypoints = tpl_pose_meta['keypoints_body']

            if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:
                if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) or \
                   (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \
                   (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or \
                   (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75):
                    arm_visible = True

            if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:
                if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or \
                   (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \
                   (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or \
                   (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75):
                    leg_visible = True

            if arm_visible and leg_visible:
                break

        if leg_visible:
            if tpl_pose_meta['width'] > tpl_pose_meta['height']:
                tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
            else:
                tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
            if refer_pose_meta['width'] > refer_pose_meta['height']:
                refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
            else:
                refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
        elif arm_visible:
            if tpl_pose_meta['width'] > tpl_pose_meta['height']:
                tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
            else:
                tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
            if refer_pose_meta['width'] > refer_pose_meta['height']:
                refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
            else:
                refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
        else:
            tpl_prompt = "Change the person to face forward."
            refer_prompt = "Change the person to face forward."

        return tpl_prompt, refer_prompt
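
    # Note: the body-keypoint indices tested above (3, 4, 6, 7 for arms; 9, 10, 12, 13
    # for legs) appear to follow an OpenPose-style body ordering (elbows/wrists and
    # knees/ankles); coordinates are assumed normalized to [0, 1] with a confidence
    # score in the third channel, where >= 0.75 counts as visible.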

    def get_mask_from_face_bbox_v2(
        self,
        frames,
        pts_by_frame: dict[int, list[list[float]]] | None = None,
        lbs_by_frame: dict[int, list[int | float]] | None = None,
    ):
        """
        Args:
            frames: list/array of HxWx3 uint8 frames.
            pts_by_frame: {frame_idx: [[x, y], ...], ...}
            lbs_by_frame: {frame_idx: [0/1, ...], ...}

        Returns:
            all_mask: list of np.uint8 masks of length len(frames), each (H, W) in {0, 1}.
        """
        print(f"lbs_by_frame: {lbs_by_frame}")
        print(f"pts_by_frame: {pts_by_frame}")

        # --- Safety & normalization ---
        if pts_by_frame is None:
            pts_by_frame = {}
        if lbs_by_frame is None:
            lbs_by_frame = {}

        # Normalize keys to int (in case they arrived as strings).
        pts_by_frame = {int(k): v for k, v in pts_by_frame.items()}
        lbs_by_frame = {int(k): v for k, v in lbs_by_frame.items()}

        H, W = frames[0].shape[:2]
        device = "cuda" if torch.cuda.is_available() else "cpu"

        with torch.autocast(device_type=device, dtype=torch.bfloat16 if device == "cuda" else torch.float16):
            # 1) Init the SAM2 video predictor state.
            inference_state = self.predictor.init_state(images=np.array(frames), device=device)

            # 2) Feed all per-frame clicks before propagating. The same obj_id (0) is
            #    used for every click, so all clicks describe one object regardless of
            #    which frame they were added on.
            for fidx in sorted(pts_by_frame.keys()):
                pts = np.array(pts_by_frame.get(fidx, []), dtype=np.float32)
                lbs = np.array(lbs_by_frame.get(fidx, []), dtype=np.int32)
                if pts.size == 0:
                    continue  # nothing to add for this frame
                # Sanity check: points and labels must line up.
                if len(pts) != len(lbs):
                    raise ValueError(f"Points/labels length mismatch at frame {fidx}: {len(pts)} vs {len(lbs)}")
                self.predictor.add_new_points(
                    inference_state=inference_state,
                    frame_idx=int(fidx),
                    obj_id=0,
                    points=pts,
                    labels=lbs,
                )

            # 3) Propagate across the whole video.
            video_segments = {}
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                inference_state, start_frame_idx=0
            ):
                # Store boolean masks per object id for this frame.
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).to("cpu").numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

            # 4) Collect masks in frame order; fall back to zeros where the predictor
            #    returned nothing.
            all_mask = []
            zero_mask = np.zeros((H, W), dtype=np.uint8)
            for out_frame_idx in range(len(frames)):
                if out_frame_idx in video_segments and len(video_segments[out_frame_idx]) > 0:
                    mask = next(iter(video_segments[out_frame_idx].values()))
                    if mask.ndim == 3:  # (1, H, W) -> (H, W)
                        mask = mask[0]
                    mask = mask.astype(np.uint8)
                else:
                    mask = zero_mask
                all_mask.append(mask)

        return all_mask
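
    # Illustrative click layout for get_mask_from_face_bbox_v2 (hypothetical values):
    #   pts_by_frame = {0: [[320.0, 180.0], [300.0, 420.0]], 15: [[310.0, 200.0]]}
    #   lbs_by_frame = {0: [1, 0], 15: [1]}
    # i.e. frame 0 carries one positive and one negative click, frame 15 one positive
    # click; all clicks prompt the same object (obj_id=0) before propagation.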

    def get_mask_from_face_bbox(self, frames, th_step, kp2ds_all):
        """
        Build masks from a face bounding box per key frame (derived from keypoints_face),
        then propagate them with SAM2 across each chunk of frames.
        """
        H, W = frames[0].shape[:2]

        def _clip_box(x1, y1, x2, y2, W, H):
            x1 = max(0, min(int(x1), W - 1))
            x2 = max(0, min(int(x2), W - 1))
            y1 = max(0, min(int(y1), H - 1))
            y2 = max(0, min(int(y2), H - 1))
            if x2 <= x1:
                x2 = min(W - 1, x1 + 1)
            if y2 <= y1:
                y2 = min(H - 1, y1 + 1)
            return x1, y1, x2, y2

        frame_num = len(frames)
        if frame_num < th_step:
            num_step = 1
        else:
            num_step = (frame_num + th_step) // th_step

        all_mask = []
        for step_idx in range(num_step):
            each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
            kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
            if len(each_frames) == 0:
                continue

            # Pick a few key frames in this chunk.
            key_frame_num = 4 if len(each_frames) > 4 else 1
            key_frame_step = max(1, len(kp2ds) // key_frame_num)
            key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]

            # Compute face boxes on the selected key frames.
            # Note: get_face_bboxes returns (x1, x2, y1, y2).
            key_frame_boxes = []
            for kfi in key_frame_index_list:
                meta = kp2ds[kfi]
                x1, x2, y1, y2 = get_face_bboxes(
                    meta['keypoints_face'][:, :2],
                    scale=1.3,
                    image_shape=(H, W)
                )
                x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2, W, H)
                key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))

            # Init SAM2 for this chunk.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                inference_state = self.predictor.init_state(images=np.array(each_frames), device="cuda")
                self.predictor.reset_state(inference_state)
                ann_obj_id = 1

                # Seed with box prompts when available, otherwise fall back to point prompts.
                for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
                    used_box = False
                    try:
                        # Preferred path: a box API, if the predictor exposes one.
                        _ = self.predictor.add_new_box(
                            inference_state=inference_state,
                            frame_idx=ann_frame_idx,
                            obj_id=ann_obj_id,
                            box=box_xyxy[None, :]  # shape (1, 4)
                        )
                        used_box = True
                    except Exception:
                        used_box = False

                    if not used_box:
                        # Fallback: sample a few positive points inside the box.
                        x1, y1, x2, y2 = box_xyxy.astype(int)
                        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
                        pts = np.array([
                            [cx, cy],
                            [x1 + (x2 - x1) // 4, cy],
                            [x2 - (x2 - x1) // 4, cy],
                            [cx, y1 + (y2 - y1) // 4],
                            [cx, y2 - (y2 - y1) // 4],
                        ], dtype=np.int32)
                        labels = np.ones(len(pts), dtype=np.int32)  # 1 = positive click
                        _ = self.predictor.add_new_points(
                            inference_state=inference_state,
                            frame_idx=ann_frame_idx,
                            obj_id=ann_obj_id,
                            points=pts,
                            labels=labels,
                        )

                # Propagate across the chunk.
                video_segments = {}
                for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                    video_segments[out_frame_idx] = {
                        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                        for i, out_obj_id in enumerate(out_obj_ids)
                    }

                # Collect masks (single object id) in frame order.
                for out_frame_idx in range(len(video_segments)):
                    mask = next(iter(video_segments[out_frame_idx].values()))  # (1, H, W)
                    mask = mask[0].astype(np.uint8)
                    all_mask.append(mask)

        return all_mask
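

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only; the checkpoint and media paths below
# are placeholders, not part of the original pipeline). Pose-only preprocessing
# is shown, so the SAM2 and Flux-Kontext checkpoints can be left as None.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline = ProcessPipeline(
        det_checkpoint_path="checkpoints/det.pt",        # placeholder detector weights
        pose2d_checkpoint_path="checkpoints/pose2d.pt",  # placeholder pose2d weights
        sam_checkpoint_path=None,                        # not needed when replace_flag=False
        flux_kontext_path=None,                          # not needed when use_flux=False
    )
    os.makedirs("outputs", exist_ok=True)
    pipeline(
        video_path="template.mp4",          # placeholder driving video
        refer_image_path="reference.png",   # placeholder reference image
        output_path="outputs",
        resolution_area=[1280, 720],
        fps=30,
        retarget_flag=False,
        replace_flag=False,
    )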