# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import Tuple, Union

import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.models.attention import AdaLayerNorm

from ..model import WanAttentionBlock, WanCrossAttention
from .auxi_blocks import MotionEncoder_tc

class CausalAudioEncoder(nn.Module):

    def __init__(self,
                 dim=5120,
                 num_layers=25,
                 out_dim=2048,
                 video_rate=8,
                 num_token=4,
                 need_global=False):
        super().__init__()
        self.encoder = MotionEncoder_tc(
            in_dim=dim,
            hidden_dim=out_dim,
            num_heads=num_token,
            need_global=need_global)
        weight = torch.ones((1, num_layers, 1, 1)) * 0.01
        self.weights = torch.nn.Parameter(weight)
        self.act = torch.nn.SiLU()

    def forward(self, features):
        with amp.autocast(dtype=torch.float32):
            # features: B * num_layers * dim * video_length
            weights = self.act(self.weights)
            weights_sum = weights.sum(dim=1, keepdims=True)
            weighted_feat = ((features * weights) / weights_sum).sum(
                dim=1)  # b dim f
            weighted_feat = weighted_feat.permute(0, 2, 1)  # b f dim
            res = self.encoder(weighted_feat)  # b f n dim

        return res  # b f n dim
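

# A minimal usage sketch of CausalAudioEncoder (not part of the original
# module). The batch size, feature dim, and frame count below are
# illustrative assumptions; the exact token layout of the output depends on
# how MotionEncoder_tc processes the weighted features.
def _causal_audio_encoder_example():
    encoder = CausalAudioEncoder(
        dim=1024, num_layers=25, out_dim=2048, num_token=4)
    # Stacked per-layer audio features: batch * num_layers * dim * frames.
    features = torch.randn(2, 25, 1024, 16)
    tokens = encoder(features)  # expected layout: b f n dim
    return tokens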

class AudioCrossAttention(WanCrossAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

class AudioInjector_WAN(nn.Module):

    def __init__(self,
                 all_modules,
                 all_modules_names,
                 dim=2048,
                 num_heads=32,
                 inject_layer=[0, 27],
                 root_net=None,
                 enable_adain=False,
                 adain_dim=2048,
                 need_adain_ont=False):
        super().__init__()
        num_injector_layers = len(inject_layer)
        self.injected_block_id = {}
        audio_injector_id = 0
        for mod_name, mod in zip(all_modules_names, all_modules):
            if isinstance(mod, WanAttentionBlock):
                for inject_id in inject_layer:
                    if f'transformer_blocks.{inject_id}' in mod_name:
                        self.injected_block_id[inject_id] = audio_injector_id
                        audio_injector_id += 1

        self.injector = nn.ModuleList([
            AudioCrossAttention(
                dim=dim,
                num_heads=num_heads,
                qk_norm=True,
            ) for _ in range(audio_injector_id)
        ])
        self.injector_pre_norm_feat = nn.ModuleList([
            nn.LayerNorm(
                dim,
                elementwise_affine=False,
                eps=1e-6,
            ) for _ in range(audio_injector_id)
        ])
        self.injector_pre_norm_vec = nn.ModuleList([
            nn.LayerNorm(
                dim,
                elementwise_affine=False,
                eps=1e-6,
            ) for _ in range(audio_injector_id)
        ])
        if enable_adain:
            self.injector_adain_layers = nn.ModuleList([
                AdaLayerNorm(
                    output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)
                for _ in range(audio_injector_id)
            ])
            if need_adain_ont:
                self.injector_adain_output_layers = nn.ModuleList(
                    [nn.Linear(dim, dim) for _ in range(audio_injector_id)])
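

# A minimal sketch (not part of the original file) of how the
# injected_block_id mapping built above could be consumed: given the index of
# a WanAttentionBlock selected via inject_layer, look up the cross-attention
# injector and pre-norm dedicated to that block. The helper name and its
# return convention are illustrative assumptions; the actual injection path
# lives in the surrounding WAN model code.
def _lookup_audio_injector(audio_injector, block_id):
    # Returns (cross_attention, pre_norm_feat) for the given transformer
    # block, or None when that block was not listed in `inject_layer`.
    if block_id not in audio_injector.injected_block_id:
        return None
    idx = audio_injector.injected_block_id[block_id]
    return (audio_injector.injector[idx],
            audio_injector.injector_pre_norm_feat[idx])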