# Lumina2 pixel decoder: a training wrapper (Lumina2Decoder) and an
# InstructPix2Pix-style inference pipeline (Lumina2InstructPix2PixPipeline).
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Pipeline
from diffusers.training_utils import (
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    free_memory,
)
from transformers import AutoTokenizer, Gemma2Model

# The star import supplies the remaining names used below: the typing aliases
# (Union, Optional, List, Dict, Any, Callable, Tuple) plus np, calculate_shift,
# retrieve_timesteps, ImagePipelineOutput, XLA_AVAILABLE and xm.
from diffusers.pipelines.lumina2.pipeline_lumina2 import *


class Lumina2Decoder(torch.nn.Module):
    """Training-time wrapper around the Lumina2 components (tokenizer, text encoder,
    VAE, flow-matching scheduler and transformer)."""

    def __init__(self, config, args):
        super().__init__()
        self.diffusion_model_path = config.model_path

        if not hasattr(args, "revision"):
            args.revision = None
        if not hasattr(args, "variant"):
            args.variant = None

        self.tokenizer_one = AutoTokenizer.from_pretrained(
            self.diffusion_model_path,
            subfolder="tokenizer",
            revision=args.revision,
        )
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.diffusion_model_path, subfolder="scheduler"
        )
        # Keep an untouched copy for timestep/sigma lookups during training.
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
        self.text_encoder_one = Gemma2Model.from_pretrained(
            self.diffusion_model_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
        )
        # Lightweight pipeline used only for prompt encoding (no VAE / transformer loaded).
        self.text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
            self.diffusion_model_path,
            vae=None,
            transformer=None,
            text_encoder=self.text_encoder_one,
            tokenizer=self.tokenizer_one,
        )
        self.vae = AutoencoderKL.from_pretrained(
            self.diffusion_model_path,
            subfolder="vae",
            revision=args.revision,
            variant=args.variant,
        )

        # Select the transformer variant matching how the original image is injected.
        if args.ori_inp_dit == "seq":
            from star.models.pixel_decoder.transformer_lumina2_seq import Lumina2Transformer2DModel
        elif args.ori_inp_dit == "ref":
            from star.models.pixel_decoder.transformer_lumina2 import Lumina2Transformer2DModel

        self.transformer = Lumina2Transformer2DModel.from_pretrained(
            self.diffusion_model_path, subfolder="transformer", revision=args.revision, variant=args.variant
        )

        # Widen the patch-embedding input so the transformer can take the VQ feature
        # map concatenated with the noisy latents along the channel dimension.
        vq_dim = 512
        patch_size = self.transformer.config.patch_size
        in_channels = vq_dim + self.transformer.config.in_channels
        out_channels = self.transformer.x_embedder.out_features

        load_num_channel = self.transformer.config.in_channels * patch_size * patch_size
        self.transformer.register_to_config(in_channels=in_channels)
        transformer = self.transformer
        with torch.no_grad():
            new_proj = nn.Linear(
                in_channels * patch_size * patch_size, out_channels, bias=True
            )
            # Zero-init, then copy the pretrained weights into the original input slots
            # so the extra VQ channels start with no influence.
            new_proj.weight.zero_()
            new_proj = new_proj.to(transformer.x_embedder.weight.dtype)
            new_proj.weight[:, :load_num_channel].copy_(transformer.x_embedder.weight)
            new_proj.bias.copy_(transformer.x_embedder.bias)
            transformer.x_embedder = new_proj

        self.ori_inp_dit = args.ori_inp_dit
        if args.ori_inp_dit == "seq":
            # Zero-initialized refiner MLP for the original-image tokens appended
            # along the sequence dimension.
            refiner_channels = transformer.noise_refiner[-1].dim
            with torch.no_grad():
                vae2cond_proj1 = nn.Linear(refiner_channels, refiner_channels, bias=True)
                vae2cond_act = nn.GELU(approximate='tanh')
                vae2cond_proj2 = nn.Linear(refiner_channels, refiner_channels, bias=False)
                vae2cond_proj2.weight.zero_()

                ori_inp_refiner = nn.Sequential(
                    vae2cond_proj1,
                    vae2cond_act,
                    vae2cond_proj2,
                )
                transformer.ori_inp_refiner = ori_inp_refiner
                transformer.ori_inp_dit = self.ori_inp_dit
        elif args.ori_inp_dit == "ref":
            transformer.initialize_ref_weights()
            transformer.ori_inp_dit = self.ori_inp_dit

        transformer.requires_grad_(True)

        # Gradient checkpointing only at the highest training resolution.
        if args.grad_ckpt and args.diffusion_resolution == 1024:
            transformer.gradient_checkpointing = args.grad_ckpt
            transformer.enable_gradient_checkpointing()

        self.vae.requires_grad_(False)
        self.vae.to(dtype=torch.float32)
        self.args = args

        # Inference pipeline (class defined below) sharing the modified transformer,
        # text encoder and VAE.
        self.pipe = Lumina2InstructPix2PixPipeline.from_pretrained(
            self.diffusion_model_path,
            transformer=transformer,
            text_encoder=self.text_encoder_one,
            vae=self.vae,
            torch_dtype=torch.bfloat16,
        )

        # Pre-compute the empty-prompt ("") embeddings used for conditioning dropout.
        with torch.no_grad():
            _, _, self.uncond_prompt_embeds, self.uncond_prompt_attention_mask = self.text_encoding_pipeline.encode_prompt(
                "",
                max_sequence_length=self.args.max_diff_seq_length,
            )

    def compute_text_embeddings(self, prompt, text_encoding_pipeline):
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
                prompt,
                max_sequence_length=self.args.max_diff_seq_length,
            )
        return prompt_embeds, prompt_attention_mask

    def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
        """Look up the scheduler sigma for each sampled timestep and broadcast it to `n_dim` dims."""
        sigmas = self.noise_scheduler_copy.sigmas.to(dtype=dtype)
        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device=timesteps.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def forward(self, batch_gpu, batch, image_embeds):
        args = self.args
        pixel_values = batch_gpu["full_pixel_values"].to(dtype=self.vae.dtype)
        data_type = "t2i"
        if len(pixel_values.shape) == 5:
            bs, num_img, c, h, w = pixel_values.shape
            if num_img == 2:
                data_type = "edit"
            # First image is the original (pre-edit) image, last image is the target.
            pixel_values_ori_img = pixel_values[:, 0]
            pixel_values = pixel_values[:, -1]
        # Resize the target image to the training resolution before VAE encoding.
        pixel_values = F.interpolate(
            pixel_values,
            size=(self.args.diffusion_resolution, self.args.diffusion_resolution),
            mode='bilinear', align_corners=False,
        )
        if data_type == "edit" and self.ori_inp_dit != "none":
            pixel_values_ori_img = F.interpolate(
                pixel_values_ori_img,
                size=(self.args.diffusion_resolution, self.args.diffusion_resolution),
                mode='bilinear', align_corners=False,
            )
        prompt = batch["prompts"]
        bs, _, _, _ = pixel_values.shape
        image_prompt_embeds = None
        # Reshape the flat image tokens (24x24 grid) into a 2D feature map and resize
        # it to the latent resolution.
        image_embeds_2d = image_embeds.reshape(bs, 24, 24, image_embeds.shape[-1]).permute(0, 3, 1, 2)
        image_embeds_2d = F.interpolate(
            image_embeds_2d,
            size=(args.diffusion_resolution // 8, args.diffusion_resolution // 8),
            mode='bilinear', align_corners=False,
        )

        control_emd = args.control_emd
        prompt_embeds, prompt_attention_mask = self.compute_text_embeddings(prompt, self.text_encoding_pipeline)
        if control_emd == "mix":
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
        elif control_emd == "null":
            prompt_embeds = torch.zeros_like(prompt_embeds)
            prompt_attention_mask = torch.ones_like(prompt_attention_mask)
        elif control_emd == "text":
            pass
        elif control_emd in ("vit", "vq", "vqvae", "vqconcat", "vqconcatvit"):
            prompt_embeds = image_prompt_embeds

        # Encode the target image into VAE latents and normalize them.
        latents = self.vae.encode(pixel_values).latent_dist.sample()
        latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        latents = latents.to(dtype=image_embeds.dtype)

        # Latents of the original (pre-edit) image; zeros for pure t2i samples.
        latents_ori_img = torch.zeros_like(latents)
        if data_type == "edit" and self.ori_inp_dit != "none":
            latents_ori_img = self.vae.encode(pixel_values_ori_img).latent_dist.sample()
            latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            latents_ori_img = latents_ori_img.to(dtype=image_embeds.dtype)

        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # Sample one timestep per example according to the configured weighting scheme.
        u = compute_density_for_timestep_sampling(
            weighting_scheme=args.weighting_scheme,
            batch_size=bsz,
            logit_mean=args.logit_mean,
            logit_std=args.logit_std,
            mode_scale=args.mode_scale,
        )
        indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
        timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)

        # Flow-matching forward process: interpolate between clean latents and noise.
        sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype).to(device=noise.device)
        noisy_model_input = sigmas * noise + (1 - sigmas) * latents

        original_image_embeds = image_embeds_2d

        if args.conditioning_dropout_prob is not None:
            # Classifier-free-guidance style conditioning dropout: randomly drop the
            # text prompt, the image-embedding condition, or the original-image latents.
            random_p = torch.rand(bsz, device=latents.device)

            # Drop the text prompt when random_p < 2 * uncondition_prob.
            prompt_mask = random_p < 2 * args.uncondition_prob
            prompt_mask = prompt_mask.reshape(bsz, 1, 1)

            prompt_embeds = torch.where(
                prompt_mask,
                self.uncond_prompt_embeds.repeat(prompt_embeds.shape[0], 1, 1).to(prompt_embeds.device),
                prompt_embeds,
            )
            prompt_attention_mask = torch.where(
                prompt_mask[:, 0],
                self.uncond_prompt_attention_mask.repeat(prompt_embeds.shape[0], 1).to(prompt_embeds.device),
                prompt_attention_mask,
            )

            # Drop the image-embedding condition when random_p <= conditioning_dropout_prob.
            image_mask_dtype = original_image_embeds.dtype
            image_mask = 1 - (
                (random_p <= args.conditioning_dropout_prob).to(image_mask_dtype)
            )
            image_mask = image_mask.reshape(bsz, 1, 1, 1)

            # Editing samples never use the image-embedding condition.
            if data_type == "edit":
                image_mask = 0

            original_image_embeds = image_mask * original_image_embeds

            # Drop the original-image latents for random_p in [uncondition_prob, 3 * uncondition_prob).
            ori_latent_mask = 1 - (
                (random_p >= args.uncondition_prob).to(image_mask_dtype)
                * (random_p < 3 * args.uncondition_prob).to(image_mask_dtype)
            )
            ori_latent_mask = ori_latent_mask.reshape(bsz, 1, 1, 1)
            latents_ori_img = ori_latent_mask * latents_ori_img

        # Channel-concatenate the noisy latents with the (possibly dropped) image embeds,
        # then attach the original-image latents according to the chosen injection mode.
        concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)

        ref_image_hidden_states = None
        if self.ori_inp_dit == "dim":
            # Extra channels.
            concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=1)
        elif self.ori_inp_dit == "seq":
            # Extra rows appended along the height axis (token sequence).
            latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
            concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=2)
        elif self.ori_inp_dit == "ref":
            # Separate reference stream consumed by the "ref" transformer variant.
            latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
            ref_image_hidden_states = latents_ori_img[:, None]

        # Convert scheduler timesteps to the normalized, reversed convention used by the
        # Lumina2 transformer (same convention as in the pipeline below).
        timesteps = 1 - timesteps / self.noise_scheduler.config.num_train_timesteps
        model_pred = self.transformer(
            hidden_states=concatenated_noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_attention_mask,
            return_dict=False,
        )[0]
        if self.ori_inp_dit == "seq":
            # Keep only the rows belonging to the noisy latents; drop the appended
            # original-image rows.
            model_pred = model_pred[:, :, :args.diffusion_resolution // 8, :]

        # Flow-matching loss: the model regresses the velocity (latents - noise).
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
        target = latents - noise

        loss = torch.mean(
            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
            1,
        )
        loss = loss.mean()

        loss_value = loss.item()

        return loss

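
# A minimal, standalone sketch (not used by the training path above) of the
# flow-matching construction that Lumina2Decoder.forward implements: latents and
# noise are interpolated with a per-sample sigma and the regression target is the
# velocity (latents - noise). The tensor shapes are illustrative assumptions only.
def _flow_matching_target_sketch():
    latents = torch.randn(2, 16, 32, 32)        # clean VAE latents (illustrative shape)
    noise = torch.randn_like(latents)           # Gaussian noise
    sigmas = torch.rand(2).view(-1, 1, 1, 1)    # per-sample sigma in [0, 1]

    noisy_model_input = sigmas * noise + (1 - sigmas) * latents  # forward interpolation
    target = latents - noise                                     # velocity target

    # With a perfect prediction the (unweighted) per-sample loss is zero.
    model_pred = target.clone()
    loss = ((model_pred - target) ** 2).reshape(target.shape[0], -1).mean(dim=1).mean()
    return noisy_model_input, target, loss

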
class Lumina2InstructPix2PixPipeline(Lumina2Pipeline):
    """Lumina2 pipeline extended with InstructPix2Pix-style image conditioning
    (VQ feature map and/or original-image latents) and a three-way classifier-free
    guidance over [unconditional, image, text + image]."""

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 4.0,
        negative_prompt: Union[str, List[str]] = None,
        sigmas: List[float] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        system_prompt: Optional[str] = None,
        cfg_trunc_ratio=[0.0, 1.0],
        cfg_normalization: bool = False,
        max_sequence_length: int = 256,
        control_emd="text",
        img_cfg_trunc_ratio=[0.0, 1.0],
        gen_image_embeds=None,
        only_t2i="vqconcat",
        image_prompt_embeds=None,
        ori_inp_img=None,
        img_guidance_scale=1.5,
        vq_guidance_scale=0,
        ori_inp_way="none",
    ) -> Union[ImagePipelineOutput, Tuple]:
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs

        # The number of images is dictated by the number of provided image embeddings.
        num_images_per_prompt = gen_image_embeds.shape[0] if gen_image_embeds is not None else image_prompt_embeds.shape[0]

        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
            system_prompt=system_prompt,
        )

        if gen_image_embeds is not None:
            image_embeds_8 = gen_image_embeds

        if control_emd == "text":
            pass
        elif control_emd == "null":
            prompt_embeds = torch.zeros_like(prompt_embeds)
            prompt_attention_mask = torch.zeros_like(prompt_attention_mask)
            negative_prompt_embeds = prompt_embeds
            negative_prompt_attention_mask = prompt_attention_mask

        if self.do_classifier_free_guidance:
            # Order: [unconditional, image-only, text + image] for the three-way guidance.
            prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat(
                [negative_prompt_attention_mask, negative_prompt_attention_mask, prompt_attention_mask], dim=0
            )

        latent_channels = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Latents of the original image to be edited (zeros when not editing).
        latents_ori_img = torch.zeros_like(latents)
        if ori_inp_img is not None and ori_inp_way != "none":
            ori_inp_img = F.interpolate(
                ori_inp_img[None].to(latents.device, latents.dtype),
                size=(height, width), mode='bilinear', align_corners=False,
            )
            latents_ori_img = self.vae.encode(ori_inp_img).latent_dist.sample()
            latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            latents_ori_img = latents_ori_img.to(dtype=latents.dtype)
        if ori_inp_way != "none":
            negative_latents_ori_img = torch.zeros_like(latents_ori_img).to(prompt_embeds.dtype)
            latents_ori_img = (
                torch.cat([negative_latents_ori_img, latents_ori_img, latents_ori_img], dim=0)
                if self.do_classifier_free_guidance
                else latents_ori_img
            )

        # Build the image-conditioning latents (VQ feature map resized to latent resolution).
        vq_in_edit = False
        if only_t2i == True:
            image_latents = torch.zeros_like(latents)[:, :8]
        elif only_t2i == "vqconcat":
            image_embeds_2d = image_embeds_8.reshape(
                batch_size * num_images_per_prompt, 24, 24, image_embeds_8.shape[-1]
            ).permute(0, 3, 1, 2)
            if ori_inp_img is not None and image_embeds_8.mean() != 0:
                # Editing with VQ tokens available: keep the VQ map for a separate
                # guidance branch and zero the main image-conditioning channels.
                vq_in_edit = True
                image_vq_latents = F.interpolate(
                    image_embeds_2d, size=(height // 8, width // 8), mode='bilinear', align_corners=False
                ).to(latents.device, latents.dtype)
                image_latents = torch.zeros_like(image_vq_latents)
            else:
                image_latents = F.interpolate(
                    image_embeds_2d, size=(height // 8, width // 8), mode='bilinear', align_corners=False
                ).to(latents.device, latents.dtype)

        negative_image_latents = torch.zeros_like(image_latents).to(prompt_embeds.dtype)
        image_latents = (
            torch.cat([negative_image_latents, image_latents, image_latents], dim=0)
            if self.do_classifier_free_guidance
            else image_latents
        )

        # Prepare sigmas/timesteps with the resolution-dependent shift used by Lumina2.
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        self.scheduler.sigmas = self.scheduler.sigmas.to(latents.dtype)

        # Denoising loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Guidance truncation: outside the configured step-ratio window a
                # guidance term falls back to scale 1; if both terms are truncated,
                # only the fully conditioned prediction is used (see below).
                do_classifier_free_truncation = not (
                    (i + 1) / num_inference_steps > cfg_trunc_ratio[0]
                    and (i + 1) / num_inference_steps < cfg_trunc_ratio[1]
                )
                img_do_classifier_free_truncation = not (
                    (i + 1) / num_inference_steps > img_cfg_trunc_ratio[0]
                    and (i + 1) / num_inference_steps < img_cfg_trunc_ratio[1]
                )

                # Lumina2 uses a reversed, normalized timestep.
                current_timestep = 1 - t / self.scheduler.config.num_train_timesteps

                # Three copies for [unconditional, image-only, text + image] guidance.
                latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
                current_timestep = current_timestep.expand(latent_model_input.shape[0])

                # Channel-concatenate the image-conditioning latents.
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

                ref_image_hidden_states = None
                if ori_inp_way == "seq":
                    latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
                    latent_model_input = torch.cat([latent_model_input, latents_ori_img_cat], dim=2)
                elif ori_inp_way == "ref":
                    latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
                    ref_image_hidden_states = latents_ori_img_cat[:, None]

if ori_inp_way=="ref": |
|
|
noise_pred = self.transformer( |
|
|
hidden_states=latent_model_input, |
|
|
timestep=current_timestep, |
|
|
encoder_hidden_states=prompt_embeds, |
|
|
encoder_attention_mask=prompt_attention_mask, |
|
|
return_dict=False,ref_image_hidden_states=ref_image_hidden_states, |
|
|
attention_kwargs=self.attention_kwargs, |
|
|
)[0] |
|
|
else: |
|
|
noise_pred = self.transformer( |
|
|
hidden_states=latent_model_input, |
|
|
timestep=current_timestep, |
|
|
encoder_hidden_states=prompt_embeds, |
|
|
encoder_attention_mask=prompt_attention_mask, |
|
|
return_dict=False, |
|
|
attention_kwargs=self.attention_kwargs, |
|
|
)[0] |
|
|
if ori_inp_way=="seq": |
|
|
noise_pred = noise_pred[:,:,:height//8,:] |
|
|
|
|
|
                if vq_in_edit:
                    # Extra fully-conditioned pass with the VQ feature map as the image
                    # condition, used below for VQ guidance.
                    latent_model_vq_input = torch.cat([latents, image_vq_latents], dim=1)
                    if ori_inp_way == "seq":
                        latents_ori_img_cat_vq = torch.cat([torch.zeros_like(latents), image_vq_latents], dim=1)
                        latent_model_vq_input = torch.cat([latent_model_vq_input, latents_ori_img_cat_vq], dim=2)

                    noise_vq_pred = self.transformer(
                        hidden_states=latent_model_vq_input,
                        timestep=current_timestep[-1:],
                        encoder_hidden_states=prompt_embeds[-1:],
                        encoder_attention_mask=prompt_attention_mask[-1:],
                        return_dict=False,
                        attention_kwargs=self.attention_kwargs,
                    )[0]
                    if ori_inp_way == "seq":
                        noise_vq_pred = noise_vq_pred[:, :, :height // 8, :]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_img, noise_pred_text = noise_pred.chunk(3)
                    # Three-way guidance: text guidance on top of image guidance, with each
                    # scale truncated to 1 outside its step window.
                    if not do_classifier_free_truncation and not img_do_classifier_free_truncation:
                        noise_pred = (
                            noise_pred_uncond
                            + guidance_scale * (noise_pred_text - noise_pred_img)
                            + img_guidance_scale * (noise_pred_img - noise_pred_uncond)
                        )
                    elif not do_classifier_free_truncation and img_do_classifier_free_truncation:
                        noise_pred = (
                            noise_pred_uncond
                            + guidance_scale * (noise_pred_text - noise_pred_img)
                            + 1 * (noise_pred_img - noise_pred_uncond)
                        )
                    elif do_classifier_free_truncation and not img_do_classifier_free_truncation:
                        noise_pred = (
                            noise_pred_uncond
                            + 1 * (noise_pred_text - noise_pred_img)
                            + img_guidance_scale * (noise_pred_img - noise_pred_uncond)
                        )
                    else:
                        noise_pred = noise_pred_text
                    if vq_in_edit:
                        noise_pred = noise_pred + vq_guidance_scale * (noise_vq_pred - noise_pred_uncond)

                    if cfg_normalization:
                        # Rescale the guided prediction to the norm of the conditional prediction.
                        cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
                        noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                        noise_pred = noise_pred * (cond_norm / noise_norm)

                latents_dtype = latents.dtype
                # The transformer predicts the reversed-time velocity; negate it for the scheduler.
                noise_pred = -noise_pred
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # Cast back to avoid dtype drift on backends (e.g. MPS) affected by a known PyTorch bug.
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # Undo the latent normalization and decode back to image space.
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
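

# A hypothetical usage sketch of the pipeline above, kept out of the import path.
# It assumes a checkpoint exported from a trained Lumina2Decoder (so the transformer's
# input projection already accepts the extra 512 VQ channels); the path, embedding
# shape, device and guidance settings are illustrative assumptions, not values
# prescribed by this module.
def _example_inference_sketch():
    model_path = "/path/to/finetuned/lumina2/checkpoint"  # hypothetical
    pipe = Lumina2InstructPix2PixPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # One sample of 24x24 image tokens with 512 channels, matching the vq_dim and the
    # 24x24 reshape used in Lumina2Decoder.forward / the "vqconcat" branch above.
    gen_image_embeds = torch.randn(1, 24 * 24, 512, dtype=torch.bfloat16, device="cuda")

    images = pipe(
        prompt="a watercolor painting of a lighthouse at dusk",
        gen_image_embeds=gen_image_embeds,
        only_t2i="vqconcat",
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=4.0,
        img_guidance_scale=1.5,
    ).images
    return images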