import os
from typing import List

import torch
from PIL import Image
from safetensors.torch import load_file
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from nested_attention_processor import AttnProcessor, NestedAttnProcessor
from resampler import Resampler
from utils import get_generator


def add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token):
    """Register `placeholder_token` in both of the pipeline's tokenizers and
    initialize its embedding in both text encoders from `initializer_token`."""
    num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
    num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
    if num_added_tokens1 != 1 or num_added_tokens2 != 1:
        raise ValueError("Failed to add placeholder token to tokenizer")

    token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
    token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
    if len(token_ids1) > 1 or len(token_ids2) > 1:
        raise ValueError("The initializer token must be a single token.")
    initializer_token_id1 = token_ids1[0]
    initializer_token_id2 = token_ids2[0]

    placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
    placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])

    # Grow both embedding tables to account for the newly added token.
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))

    # Copy the initializer token's embedding into the placeholder slot.
    token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
    token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids1:
            token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
        for token_id in placeholder_token_ids2:
            token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
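
# Illustrative call (the token strings here are placeholders, not from the original code):
#   add_special_token_to_tokenizer(pipe, placeholder_token="<person>", initializer_token="person")
# Afterwards "<person>" can be used in prompts, and its ids looked up with
#   pipe.tokenizer.convert_tokens_to_ids(["<person>"])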


class NestedAdapterInference:
    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        adapter_ckpt,
        resampler_num_queries,
        vq_normalize_factor,
        device,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.adapter_ckpt = adapter_ckpt
        self.vq_normalize_factor = vq_normalize_factor

        self.pipe = sd_pipe.to(self.device)
        self.set_nested_adapter()

        # CLIP image encoder that provides per-patch image features.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path, use_safetensors=True
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # Resampler (Q-Former) that maps CLIP patch features into the UNet's
        # cross-attention embedding space.
        self.qformer = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=resampler_num_queries,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)

        if adapter_ckpt is not None:
            self.load_nested_adapter()

    def set_nested_adapter(self):
        """Install attention processors on the UNet: plain AttnProcessor for
        self-attention layers, NestedAttnProcessor for cross-attention layers."""
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # "attn1" processors are self-attention (no text conditioning).
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            # Recover the channel width of the UNet block this processor lives in.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = NestedAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    normalize_factor=self.vq_normalize_factor,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    def load_nested_adapter(self):
        state_dict = {"adapter_modules": {}, "qformer": {}}
        f = load_file(self.adapter_ckpt)
        for key in f.keys():
            if key.startswith("adapter_modules."):
                state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[key]
            elif key.startswith("spatial_features_model."):
                state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[key]
        self.qformer.load_state_dict(state_dict["qformer"])
        adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        adapter_layers.load_state_dict(state_dict["adapter_modules"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(
            images=pil_image, return_tensors="pt"
        ).pixel_values
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=torch.float16)
        )
        # Keep the per-patch (spatial) features and drop the leading CLS token.
        spatial_clip_image_embeds = clip_image_embeds.last_hidden_state[:, 1:]
        return spatial_clip_image_embeds
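
    # Shape note (illustrative assumption, not from the original code): with a
    # ViT-H/14 CLIP vision tower at 224x224 input, last_hidden_state has shape
    # [batch, 257, 1280]; dropping the CLS token leaves 256 patch tokens of width
    # 1280, which the Resampler consumes via embedding_dim. Actual shapes depend
    # on the encoder loaded from image_encoder_path.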

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        placeholder_token_ids=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=5.0,
        num_inference_steps=30,
        multiple_images=False,
        special_token_weight=1.0,
        **kwargs,
    ):
        # One prompt per reference image, unless a single image (or a set of
        # images treated as one subject via multiple_images) is given.
        if pil_image is not None:
            num_prompts = (
                1
                if isinstance(pil_image, Image.Image) or multiple_images
                else len(pil_image)
            )
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Locate the placeholder token inside the tokenized prompt; these indices
        # are handed to the nested attention processors via cross_attention_kwargs.
        text_input_ids = self.pipe.tokenizer(
            prompt,
            max_length=self.pipe.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[:, 1]

        spatial_clip_image_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )

        with torch.no_grad():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

        with torch.no_grad():
            qformer_tokens_out = self.qformer(spatial_clip_image_embeds)

        if multiple_images:
            # Concatenate the resampler tokens of all reference images along the
            # token dimension so they jointly condition a single prompt.
            b, num_tokens, d = qformer_tokens_out.shape
            qformer_tokens_out = qformer_tokens_out.reshape(1, num_tokens * b, d)

        bs_embed, num_tokens, _ = qformer_tokens_out.shape

        # Duplicate the image tokens for each sample, then once more for the
        # unconditional/conditional halves of classifier-free guidance.
        qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1, 1)
        qformer_tokens_out = qformer_tokens_out.view(bs_embed * num_samples, num_tokens, -1)
        qformer_tokens_out = qformer_tokens_out.repeat_interleave(2, dim=0)

        cross_attention_kwargs = {
            "qformer_tokens_out": qformer_tokens_out,
            "special_token_indices": special_token_indices,
            "special_token_weight": special_token_weight,
            "inference_mode": True,
        }

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).images

        return images
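

# Minimal end-to-end usage sketch. Everything below is illustrative: the base
# model id, checkpoint path, image path, and hyperparameter values
# (resampler_num_queries, vq_normalize_factor) are placeholders and must match
# whatever the adapter checkpoint was actually trained with.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    device = "cuda"  # assumes a CUDA GPU; the module hard-codes float16 weights

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder base model
        torch_dtype=torch.float16,
    )

    # Register the placeholder token in both tokenizers/text encoders.
    add_special_token_to_tokenizer(pipe, placeholder_token="<person>", initializer_token="person")
    placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<person>"])

    adapter = NestedAdapterInference(
        sd_pipe=pipe,
        image_encoder_path="path/to/clip-vision-encoder",   # placeholder
        adapter_ckpt="path/to/nested_adapter.safetensors",  # placeholder
        resampler_num_queries=16,   # placeholder, must match training
        vq_normalize_factor=1.0,    # placeholder, must match training
        device=device,
    )

    # The prompt must contain the placeholder token so that special_token_indices
    # is non-empty inside generate().
    images = adapter.generate(
        pil_image=Image.open("reference.jpg"),  # placeholder reference image
        prompt="a photo of <person> at the beach",
        placeholder_token_ids=placeholder_token_ids,
        num_samples=1,
        seed=42,
    )
    images[0].save("output.png")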