# Copyright (c) 2025. Your modifications here. # This file wraps and extends sam2.utils.misc for custom modifications. from sam2.utils import misc as sam2_misc from sam2.utils.misc import * from PIL import Image import numpy as np import torch from tqdm import tqdm import os import logging import torch from hydra import compose from hydra.utils import instantiate from omegaconf import OmegaConf from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor from sam2.build_sam import _load_checkpoint def _load_img_v2_as_tensor(img, image_size): img_pil = Image.fromarray(img.astype(np.uint8)) img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images img_np = img_np / 255.0 else: raise RuntimeError(f"Unknown image dtype: {img_np.dtype}") img = torch.from_numpy(img_np).permute(2, 0, 1) video_width, video_height = img_pil.size # the original video size return img, video_height, video_width def load_video_frames( video_path, image_size, offload_video_to_cpu, img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), async_loading_frames=False, frame_names=None, ): """ Load the video frames from a directory of JPEG files (".jpg" format). The frames are resized to image_size x image_size and are loaded to GPU if `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. You can load a frame asynchronously by setting `async_loading_frames` to `True`. """ if isinstance(video_path, str) and os.path.isdir(video_path): jpg_folder = video_path else: raise NotImplementedError("Only JPEG frames are supported at this moment") if frame_names is None: frame_names = [ p for p in os.listdir(jpg_folder) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) num_frames = len(frame_names) if num_frames == 0: raise RuntimeError(f"no images found in {jpg_folder}") img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] if async_loading_frames: lazy_images = AsyncVideoFrameLoader( img_paths, image_size, offload_video_to_cpu, img_mean, img_std ) return lazy_images, lazy_images.video_height, lazy_images.video_width images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) if not offload_video_to_cpu: images = images.cuda() img_mean = img_mean.cuda() img_std = img_std.cuda() # normalize by mean and std images -= img_mean images /= img_std return images, video_height, video_width def load_video_frames_v2( frames, image_size, offload_video_to_cpu, img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), async_loading_frames=False, frame_names=None, ): """ Load the video frames from a directory of JPEG files (".jpg" format). The frames are resized to image_size x image_size and are loaded to GPU if `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. You can load a frame asynchronously by setting `async_loading_frames` to `True`. """ num_frames = len(frames) img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) for n, frame in enumerate(tqdm(frames, desc="video frame")): images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size) if not offload_video_to_cpu: images = images.cuda() img_mean = img_mean.cuda() img_std = img_std.cuda() # normalize by mean and std images -= img_mean images /= img_std return images, video_height, video_width def build_sam2_video_predictor( config_file, ckpt_path=None, device="cuda", mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, ): hydra_overrides = [ "++model._target_=video_predictor.SAM2VideoPredictor", ] if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ # dynamically fall back to multi-mask if the single mask is not stable "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking "++model.binarize_mask_from_pts_for_mem_enc=true", # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) "++model.fill_hole_area=8", ] hydra_overrides.extend(hydra_overrides_extra) # Read config and init model cfg = compose(config_name=config_file, overrides=hydra_overrides) OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) model = model.to(device) if mode == "eval": model.eval() return model