# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------
import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType

try:
    from ldm.modules.diffusionmodules.util import timestep_embedding
    from ldm.util import instantiate_from_config
    has_ldm = True
except ImportError:
    has_ldm = False


def register_attention_control(model, controller):
    """Registers a control function to manage attention within a model.

    Args:
        model: The model to which attention is to be registered.
        controller: The control function responsible for managing attention.
    """

    def ca_forward(self, place_in_unet):
        """Custom forward method for attention.

        Args:
            self: Reference to the current object.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The modified forward method.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # `context or x` would call `bool()` on a tensor, so test for
            # None explicitly.
            context = x if context is None else context
            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

            def split_heads(tensor):
                # (b, n, h*d) -> (b*h, n, d), matching the
                # 'b n (h d) -> (b h) n d' rearrange in LDM's CrossAttention.
                b, n, hd = tensor.shape
                return tensor.reshape(b, n, h, hd // h).permute(
                    0, 2, 1, 3).reshape(b * h, n, hd // h)

            q, k, v = map(split_heads, (q, k, v))

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                # (b, j) -> (b*h, 1, j), keeping batch as the outer axis.
                mask = mask.flatten(1).unsqueeze(1).repeat_interleave(
                    h, dim=0)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # Head-averaged map for the controller: (b*h, n, m) -> (b, n, m).
            attn_mean = attn.view(-1, h, *attn.shape[1:]).mean(1)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            # Merge heads back: (b*h, n, d) -> (b, n, h*d).
            out = out.reshape(-1, h, *out.shape[1:]).permute(
                0, 2, 1, 3).reshape(out.shape[0] // h, out.shape[1], -1)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursive function to register the custom forward method to all
        CrossAttention layers.

        Args:
            net_: The network layer currently being processed.
            count: The current count of layers processed.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The updated count of layers processed.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            return count + sum(
                register_recr(child, 0, place_in_unet)
                for child in net_.children())
        return count

    # Walk the top-level blocks of the UNet and patch every CrossAttention
    # layer, tagging it with its location (down/mid/up).
    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        if 'input_blocks' in name:
            cross_att_count += register_recr(child, 0, 'down')
        elif 'output_blocks' in name:
            cross_att_count += register_recr(child, 0, 'up')
        elif 'middle_block' in name:
            cross_att_count += register_recr(child, 0, 'mid')
    controller.num_att_layers = cross_att_count
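

# A minimal sketch of how the hook above is driven (commented out; `unet`
# stands in for a real LDM UNet exposing `diffusion_model` with
# CrossAttention layers, and is a hypothetical name):
#
#   store = AttentionStore(base_size=64)
#   register_attention_control(unet, store)
#   _ = unet.diffusion_model(latents, timesteps, context)
#   # Every patched CrossAttention has now pushed its head-averaged map
#   # into `store.step_store` under '<place>_<cross|self>'.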


class AttentionStore:
    """A class for storing attention information in the UNet model.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maximum size for storing attention information.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        self.max_size = max_size or (base_size // 2)
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Returns an empty store for holding attention values."""
        return {
            key: []
            for key in [
                'down_cross', 'mid_cross', 'up_cross', 'down_self',
                'mid_self', 'up_self'
            ]
        }

    def reset(self):
        """Resets the step and attention stores to their initial states."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Processes a single forward step, storing the attention.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # Only keep maps up to `max_size` x `max_size` spatial resolution.
        if attn.shape[1] <= self.max_size**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Processes and stores attention information between steps."""
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                self.attention_store[key] = [
                    stored + step for stored, step in zip(
                        self.attention_store[key], self.step_store[key])
                ]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Returns the attention maps stored for the current step.

        Despite the name, no averaging over steps is done here; the
        per-step store is returned as-is.
        """
        return {key: list(value) for key, value in self.step_store.items()}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Allows the class instance to be callable."""
        return self.forward(attn, is_cross, place_in_unet)

    def num_uncond_att_layers(self):
        """Returns the number of unconditional attention layers (default is
        0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder for a step callback.

        Returns the input unchanged.
        """
        return x_t
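

# A minimal sketch of AttentionStore on its own (commented out; the shapes
# are illustrative assumptions — 1024 query positions fit under the default
# max_size of 32 for base_size=64):
#
#   store = AttentionStore(base_size=64)
#   attn = torch.rand(2, 1024, 77)  # (batch, queries, context tokens)
#   store(attn, is_cross=True, place_in_unet='up')
#   assert len(store.step_store['up_cross']) == 1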


class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    Args:
        unet (nn.Module): The UNet model to wrap.
        use_attn (bool): Whether to use attention. Defaults to True.
        base_size (int): Base size for the attention store. Defaults to 512.
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None.
        attn_selector (str): The types of attention to use.
            Defaults to 'up_cross+down_cross'.
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize sizes based on the base size."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model."""
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        # Encoder: keep the skip activations for the decoder.
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        # Decoder: collect intermediate feature maps at three depths.
        for i_out, module in enumerate(diffusion_model.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                # (b, h*w, c) -> (b, c, h, w); `up_attn.shape` is read before
                # the reassignment, so index 2 is still the channel dim.
                up_attn = up_attn.transpose(-1, -2).reshape(
                    up_attn.shape[0], up_attn.shape[2], size, size)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)
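

# A minimal usage sketch (commented out; `sd_model` is a hypothetical loaded
# Stable Diffusion LatentDiffusion model from the `ldm` package):
#
#   wrapper = UNetWrapper(sd_model.model, base_size=512)
#   latents = torch.randn(1, 4, 64, 64)
#   t = torch.ones(1, dtype=torch.long)
#   context = torch.randn(1, 77, 768)
#   feats = wrapper(latents, t, context=context)
#   # `feats` holds four multi-scale feature maps; the cross-attention maps
#   # chosen by `attn_selector` are concatenated onto the matching scales.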


class TextAdapter(nn.Module):
    """A PyTorch Module that serves as a text adapter.

    This module takes text embeddings and adjusts them based on a scaling
    factor gamma.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim), nn.GELU(),
            nn.Linear(text_dim, text_dim))

    def forward(self, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts
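

# A minimal, self-contained sketch of the residual text adaptation
# (commented out; the shapes below are illustrative assumptions):
#
#   adapter = TextAdapter(text_dim=768)
#   texts = torch.randn(150, 768)       # one embedding per class
#   gamma = torch.full((768, ), 1e-4)   # small initial scale
#   adapted = adapter(texts, gamma)     # texts + gamma * fc(texts)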


@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception with a pre-trained Diffusion model).

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for the diffusion model.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Configuration for U-Net.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)
        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)
        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            # Append the mean embedding as a fallback entry (index -1).
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdim=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images."""

        # calculate cross-attn map
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # forward the denoising model
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
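

# A minimal usage sketch (commented out; the config values are hypothetical —
# in mmseg the backbone is normally built from a config via the MODELS
# registry):
#
#   backbone_cfg = dict(
#       type='VPD',
#       diffusion_cfg=dict(...),  # LDM model config, plus a 'checkpoint' key
#       class_embed_path='class_embeddings.pth',  # hypothetical path
#       unet_cfg=dict(base_size=512),
#       pad_shape=512)
#   model = MODELS.build(backbone_cfg)
#   feats = model(torch.randn(1, 3, 512, 512))  # list of multi-scale maps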