Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py | |
| from transformers import pipeline | |
| import torchvision | |
| from PIL import Image | |
| from models.t2i_pipeline import StableDiffusionPipelineSpatialAware | |
| import torchvision.io as vision_io | |
| import torch.nn.functional as F | |
| import torch | |
| import tqdm | |
| import numpy as np | |
| import cv2 | |
| import warnings | |
| import time | |
| import tempfile | |
| import argparse | |
| import glob | |
| import multiprocessing as mp | |
| import os | |
| import random | |
| # fmt: off | |
| import sys | |
# Put the parent directory on the import path, right after the script's own
# directory (sys.path[0]), which therefore keeps the highest priority.
# NOTE(review): this executes after the `from models.t2i_pipeline import ...`
# line near the top of the file, so it cannot affect how that import resolves
# — TODO confirm whether it was meant to run before the imports.
sys.path[1:1] = [os.path.join(sys.path[0], '..')]
# fmt: on
# Suppress all warnings for the rest of the run.
warnings.filterwarnings("ignore")

# constants
WINDOW_NAME = "demo"
def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
                   fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
    '''
    Run the spatially-aware diffusion pipeline and return the generated image.

    pipe: the diffusion pipeline callable (in this file a
        StableDiffusionPipelineSpatialAware instance).
    overall_prompt: text prompt for the whole image.
    latents: initial noise the sampler starts from.
    get_latents: if True, also run the pipeline with output_type="latent" using
        the SAME conditioning arguments, and return its `.images` as well.
    num_inference_steps: number of denoising steps.
    fg_masks: forwarded to the pipeline as `frozen_mask`.
    fg_masked_latents: forwarded to the pipeline as `latents_all_input`.
    frozen_steps, frozen_prompt, custom_attention_mask, fg_prompt: forwarded unchanged.

    Returns the first image of the "pil" pass, or (image, latents) when
    get_latents is True.

    Side effects: writes "img.pt" (and "img_latents.pt" when get_latents is True)
    into the current working directory via torch.save.
    '''
    # One shared set of conditioning arguments, so the image pass and the latent
    # pass describe the same generation.
    pipe_kwargs = dict(
        latents=latents,
        num_inference_steps=num_inference_steps,
        frozen_mask=fg_masks,
        frozen_steps=frozen_steps,
        latents_all_input=fg_masked_latents,
        frozen_prompt=frozen_prompt,
        custom_attention_mask=custom_attention_mask,
        fg_prompt=fg_prompt,
        make_attention_mask_2d=True,
        attention_mask_block_diagonal=True,
    )
    image = pipe(overall_prompt, output_type='pil', **pipe_kwargs).images[0]
    torch.save(image, "img.pt")
    if not get_latents:
        return image
    # Dropping the mask/frozen/attention arguments here would return latents from
    # a different, unconstrained generation than `image`.
    video_latents = pipe(overall_prompt, output_type="latent", **pipe_kwargs).images
    torch.save(video_latents, "img_latents.pt")
    return image, video_latents
def save_frames(path, root_dir="demo3"):
    """
    Split the video ``<root_dir>/<path>.mp4`` into PNG frames.

    path: video file stem (no ".mp4" extension), relative to root_dir.
    root_dir: directory containing the video; defaults to "demo3" (the previous
        hard-coded location).

    Frames are written to ``<root_dir>/<path>/frame_0000.png``,
    ``frame_0001.png``, ... (the directory is created if missing).
    """
    video, _audio, _video_info = vision_io.read_video(
        os.path.join(root_dir, f"{path}.mp4"), pts_unit='sec')
    out_dir = os.path.join(root_dir, path)
    os.makedirs(out_dir, exist_ok=True)
    # torchvision's read_video returns frames as (T, H, W, C), so each frame is
    # already H x W x C and can be handed to PIL without a permute.
    for i, frame in enumerate(video):
        img = Image.fromarray(frame.numpy().astype('uint8'))
        img.save(os.path.join(out_dir, f"frame_{i:04d}.png"))
def create_boxes(*, img_width=96, img_height=96, object_sizes=(20, 30, 40, 50, 60)):
    """
    Build square box masks laid out on a 3 x 3 grid, for several box sizes.

    For each size the grid starts at (x=3, y=4). The spacing is chosen so that
    the last column/row ends at (or one pixel before) the right/bottom edge.
    For large sizes the spacing is negative, i.e. neighbouring boxes overlap.

    Keyword args (the defaults reproduce the original hard-coded layout):
        img_width, img_height: mask resolution.
        object_sizes: side lengths of the square boxes.

    Returns a list of ``9 * len(object_sizes)`` float tensors of shape
    (1, 1, img_height, img_width). Each holds 1.0 inside its box and 0.0
    elsewhere, ordered by size, then column (x), then row (y).

    Raises ValueError if a box size does not fit inside the image from the
    starting offset.
    """
    grid = 3
    start_x, start_y = 3, 4
    masks_list = []
    for object_size in object_sizes:
        if object_size > img_width - start_x or object_size > img_height - start_y:
            raise ValueError(
                f"object size {object_size} does not fit in a {img_width}x{img_height} image")
        # Split the space left over by the boxes into the (grid - 1) gaps between
        # them; floor division may go negative, which makes the boxes overlap.
        spacing_horizontal = (img_width - grid * object_size - start_x) // (grid - 1)
        spacing_vertical = (img_height - grid * object_size - start_y) // (grid - 1)
        for i in range(grid):
            x_start = start_x + i * (object_size + spacing_horizontal)
            x_end = min(x_start + object_size, img_width)  # clamp to the image edge
            for j in range(grid):
                y_start = start_y + j * (object_size + spacing_vertical)
                y_end = min(y_start + object_size, img_height)  # clamp to the image edge
                smask = torch.zeros(1, 1, img_height, img_width)
                smask[0, 0, y_start:y_end, x_start:x_end] = 1.0
                masks_list.append(smask)
    return masks_list
def objects_list():
    """
    Return the (object, setting) pairs used to build prompts such as
    ``f"A {object} {setting}"``.

    Returns a freshly built list of 166 two-tuples of strings, in a fixed order.
    Some pairs repeat (e.g. ("violin", "in an orchestra")). They are kept because
    callers select entries by position (the __main__ block slices ``[120:]``), so
    removing or reordering entries would change which prompts get generated.
    """
    objects_settings = [
        ("apple", "on a table"),
        ("ball", "in a park"),
        ("cat", "on a couch"),
        ("dog", "in a backyard"),
        ("elephant", "in a jungle"),
        ("fountain pen", "on a desk"),
        ("guitar", "on a stage"),
        ("helicopter", "in the sky"),
        ("island", "in the sea"),
        ("jar", "on a shelf"),
        ("kite", "in the sky"),
        ("lamp", "in a room"),
        ("motorbike", "on a road"),
        ("notebook", "on a table"),
        ("owl", "on a tree"),
        ("piano", "in a hall"),
        ("queen", "in a castle"),
        ("robot", "in a lab"),
        ("snake", "in a forest"),
        ("tent", "in the mountains"),
        ("umbrella", "on a beach"),
        ("violin", "in an orchestra"),
        ("wheel", "in a garage"),
        ("xylophone", "in a music class"),
        ("yacht", "in a marina"),
        ("zebra", "in a savannah"),
        ("aeroplane", "in the clouds"),
        ("bridge", "over a river"),
        ("computer", "in an office"),
        ("dragon", "in a cave"),
        ("egg", "in a nest"),
        ("flower", "in a garden"),
        ("globe", "in a library"),
        ("hat", "on a rack"),
        ("ice cube", "in a glass"),
        ("jewelry", "in a box"),
        ("kangaroo", "in a desert"),
        ("lion", "in a den"),
        ("mug", "on a counter"),
        ("nest", "on a branch"),
        ("octopus", "in the ocean"),
        ("parrot", "in a rainforest"),
        ("quilt", "on a bed"),
        ("rose", "in a vase"),
        ("ship", "in a dock"),
        ("train", "on the tracks"),
        ("utensils", "in a kitchen"),
        ("vase", "on a window sill"),
        ("watch", "in a store"),
        ("x-ray", "in a hospital"),
        ("yarn", "in a basket"),
        ("zeppelin", "above a city"),
    ]
    objects_settings.extend([
        ("muffin", "on a bakery shelf"),
        ("notebook", "on a student's desk"),
        ("owl", "in a tree"),
        ("piano", "in a concert hall"),
        ("quill", "on parchment"),
        ("robot", "in a factory"),
        ("snake", "in the grass"),
        ("telescope", "in an observatory"),
        ("umbrella", "at the beach"),
        ("violin", "in an orchestra"),
        ("whale", "in the ocean"),
        ("xylophone", "in a music store"),
        ("yacht", "in a marina"),
        ("zebra", "on a savanna"),
        # Kitchen items
        ("spoon", "in a drawer"),
        ("plate", "in a cupboard"),
        ("cup", "on a shelf"),
        ("frying pan", "on a stove"),
        ("jar", "in the refrigerator"),
        # Office items
        ("computer", "in an office"),
        ("printer", "by a desk"),
        ("chair", "around a conference table"),
        ("lamp", "on a workbench"),
        ("calendar", "on a wall"),
        # Outdoor items
        ("bicycle", "on a street"),
        ("tent", "in a campsite"),
        ("fire", "in a fireplace"),
        ("mountain", "in the distance"),
        ("river", "through the woods"),
    ])
    # Themes: (theme name, objects, shared setting). Each object is paired with
    # the theme's setting; the theme name itself is only a label.
    themes = [
        ("wild animals", ["tiger", "lion", "cheetah",
                          "giraffe", "hippopotamus"], "in the wild"),
        ("household items", ["sofa", "tv", "clock",
                             "vase", "photo frame"], "in a living room"),
        ("clothes", ["shirt", "pants", "shoes",
                     "hat", "jacket"], "in a wardrobe"),
        ("musical instruments", ["drum", "trumpet",
                                 "harp", "saxophone", "tuba"], "in a band"),
        ("cosmic entities", ["planet", "star",
                             "comet", "nebula", "asteroid"], "in space"),
    ]
    objects_settings.extend(
        (theme_object, theme_location)
        for _theme_name, theme_objects, theme_location in themes
        for theme_object in theme_objects
    )
    # Sports equipment
    objects_settings.extend([
        ("basketball", "on a court"),
        ("golf ball", "on a golf course"),
        ("tennis racket", "on a tennis court"),
        ("baseball bat", "in a stadium"),
        ("hockey stick", "on an ice rink"),
        ("football", "on a field"),
        ("skateboard", "in a skatepark"),
        ("boxing gloves", "in a boxing ring"),
        ("ski", "on a snowy slope"),
        ("surfboard", "on a beach shore"),
    ])
    # Toys and games
    objects_settings.extend([
        ("teddy bear", "on a child's bed"),
        ("doll", "in a toy store"),
        ("toy car", "on a carpet"),
        ("board game", "on a table"),
        ("yo-yo", "in a child's hand"),
        ("kite", "in the sky on a windy day"),
        ("Lego bricks", "on a construction table"),
        ("jigsaw puzzle", "partially completed"),
        ("rubik's cube", "on a shelf"),
        ("action figure", "on display"),
    ])
    # Transportation
    objects_settings.extend([
        ("bus", "at a bus stop"),
        ("motorcycle", "on a road"),
        ("helicopter", "landing on a pad"),
        ("scooter", "on a sidewalk"),
        ("train", "at a station"),
        ("bicycle", "parked by a post"),
        ("boat", "in a harbor"),
        ("tractor", "on a farm"),
        ("airplane", "taking off from a runway"),
        ("submarine", "below sea level"),
    ])
    # Medieval theme
    objects_settings.extend([
        ("castle", "on a hilltop"),
        ("knight", "riding a horse"),
        ("bow and arrow", "in an archery range"),
        ("crown", "in a treasure chest"),
        ("dragon", "flying over mountains"),
        ("shield", "next to a warrior"),
        ("dagger", "on a wooden table"),
        ("torch", "lighting a dark corridor"),
        ("scroll", "sealed with wax"),
        ("cauldron", "with bubbling potion"),
    ])
    # Modern technology
    objects_settings.extend([
        ("smartphone", "on a charger"),
        ("laptop", "in a cafe"),
        ("headphones", "around a neck"),
        ("camera", "on a tripod"),
        ("drone", "flying over a park"),
        ("USB stick", "plugged into a computer"),
        ("watch", "on a wrist"),
        ("microphone", "on a podcast desk"),
        ("tablet", "with a digital pen"),
        ("VR headset", "ready for gaming"),
    ])
    # Nature
    objects_settings.extend([
        ("tree", "in a forest"),
        ("flower", "in a garden"),
        ("mountain", "on a horizon"),
        ("cloud", "in a blue sky"),
        ("waterfall", "in a scenic location"),
        ("beach", "next to an ocean"),
        ("cactus", "in a desert"),
        ("volcano", "erupting with lava"),
        ("coral", "under the sea"),
        ("moon", "in a night sky"),
    ])
    return objects_settings
if __name__ == "__main__":
    # Generate images for each (object, setting) prompt, for 3 random box masks,
    # and for frozen_steps in 0..4. Results go to
    # SAVE_DIR/<prompt>/mask<id>/{mask.png, <frozen_steps>.png}.
    SAVE_DIR = "/scr/image/"
    save_path = "img43-att_mask"  # NOTE(review): not used in the code below
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Seeded starting noise, passed to every generate_image call below.
    random_latents = torch.randn(
        [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)
    try:
        # Try the custom cache_dir first.
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32",
            cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
    except Exception as err:
        # Fall back to the default cache. `Exception` rather than a bare `except:`
        # lets KeyboardInterrupt / SystemExit propagate. Report the fallback on
        # stderr, because warnings are globally ignored in this script.
        print(f"Loading with custom cache_dir failed ({err!r}); retrying with the default cache.",
              file=sys.stderr)
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32").to(torch_device)
    fg_object = "apple"  # fg object stores the object to be masked
    # overall prompt stores the prompt
    overall_prompt = f"An {fg_object} on plate"
    # NOTE(review): both names are overwritten inside the loop below, so this only
    # creates a directory that nothing else writes into.
    os.makedirs(f"{SAVE_DIR}/{overall_prompt}", exist_ok=True)
    masks_list = create_boxes()
    obj_settings = objects_list()  # 166 entries
    for obj_setting in obj_settings[120:]:
        fg_object = obj_setting[0]
        overall_prompt = f"A {obj_setting[0]} {obj_setting[1]}"
        print(overall_prompt)
        # Pick 3 distinct mask indices. NOTE(review): `random` is never seeded,
        # so the chosen masks vary between runs even though the latents are fixed.
        selected_mask_ids = random.sample(range(len(masks_list)), 3)
        for mask_id in selected_mask_ids:
            mask_dir = f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}"
            os.makedirs(mask_dir, exist_ok=True)
            torchvision.utils.save_image(masks_list[mask_id][0][0], f"{mask_dir}/mask.png")
            for frozen_steps in range(0, 5):
                img = generate_image(pipe, overall_prompt, random_latents, get_latents=False,
                                     num_inference_steps=50,
                                     fg_masks=masks_list[mask_id].to(torch_device),
                                     fg_masked_latents=None, frozen_steps=frozen_steps,
                                     frozen_prompt=None, fg_prompt=fg_object)
                img.save(f"{mask_dir}/{frozen_steps}.png")