# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
import torch


def _int_triple(values):
    """Unpack a length-3 tensor row (or sequence) into three Python ints."""
    first, second, third = values
    return int(first), int(second), int(third)


def _linspace_indices(first, last, count):
    """Return ``count`` evenly spaced positions in [first, last], truncated to int."""
    return np.linspace(first, last, count).astype(int).tolist()


def rope_precompute(x, grid_sizes, freqs, start=None):
    """Build the per-token complex rotary factors for every sample in ``x``.

    ``x`` is only used as a template: its last dimension is paired up and
    viewed as complex128, and the slots covered by ``grid_sizes`` are then
    overwritten with rotary frequencies gathered from ``freqs``. Slots that no
    grid entry covers keep the complex view of the original values of ``x``.

    Args:
        x: 4-D real tensor indexed as ``(batch, sequence, heads, 2 * c)``
            (the axis names are an assumption; only the rank and the even
            last dimension are required by the code).
        grid_sizes: One grid entry or a list of them. An entry is either a
            list ``[origin, end, target]`` of ``(batch, 3)`` tensors holding
            (frame, height, width) triples, or a bare ``end`` tensor, in
            which case the origin is zero. When ``target`` is missing the
            target extent defaults to ``end - origin`` (identity sampling).
            A positive target frame count resamples ``target`` positions
            down/up to ``end - origin`` indices; a negative one selects the
            trainable frequencies instead.
        freqs: Complex table with one row per position and ``c`` columns, or
            a list ``[table, trainable_freqs]``.
        start: Optional per-sample origin triples that override the origin
            stored in every grid entry.

    Returns:
        complex128 tensor with the shape of ``x`` except that the last
        dimension is halved to ``c``.

    Raises:
        ValueError: If a grid entry requests trainable frequencies but none
            were supplied, or if its target frame count is zero.
    """
    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

    # Separate the optional trainable table from the positional one.
    trainable_freqs = None
    if isinstance(freqs, (list, tuple)):
        trainable_freqs = freqs[1]
        freqs = freqs[0]
    # Columns are divided between the frame, height and width axes; the frame
    # axis receives whatever is left after two equal thirds.
    freqs_f, freqs_h, freqs_w = freqs.split(
        [c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # NOTE(review): if ``x`` is already float64 and contiguous, neither
    # ``reshape`` nor ``to`` copies, so the writes below would alias the
    # storage of ``x`` -- confirm callers never rely on ``x`` afterwards.
    output = torch.view_as_complex(
        x.detach().reshape(b, s, n, -1, 2).to(torch.float64))

    if not isinstance(grid_sizes, (list, tuple)):
        grid_sizes = [grid_sizes]

    offset = 0  # first sequence slot of the current grid entry
    for g in grid_sizes:
        if not isinstance(g, (list, tuple)):
            g = [torch.zeros_like(g), g]
        origins, ends = g[0], g[1]
        targets = g[2] if len(g) > 2 else None

        seq_len = 0  # stays 0 when the batch is empty
        for i in range(origins.shape[0]):
            f_o, h_o, w_o = _int_triple(
                origins[i] if start is None else start[i])
            f, h, w = _int_triple(ends[i])
            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
            seq_len = seq_f * seq_h * seq_w
            if seq_len <= 0:
                continue

            if targets is None:
                # No explicit target: sample every position exactly once.
                t_f, t_h, t_w = seq_f, seq_h, seq_w
            else:
                t_f, t_h, t_w = _int_triple(targets[i])

            if t_f > 0:
                # Spread seq_* indices evenly over the target extent. A
                # negative frame origin walks the mirrored range instead and
                # uses conjugated frequencies further down.
                if f_o >= 0:
                    f_sam = _linspace_indices(f_o, t_f + f_o - 1, seq_f)
                else:
                    f_sam = _linspace_indices(-f_o, -t_f - f_o + 1, seq_f)
                h_sam = _linspace_indices(h_o, t_h + h_o - 1, seq_h)
                w_sam = _linspace_indices(w_o, t_w + w_o - 1, seq_w)

                assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0, (
                    "origin and end of a grid block must not straddle zero")

                freqs_0 = freqs_f[f_sam]
                if f_o < 0:
                    freqs_0 = freqs_0.conj()
                freqs_0 = freqs_0.view(seq_f, 1, 1, -1)

                # Broadcast each axis over the (f, h, w) block and flatten it
                # into the sequence dimension; the size-1 axis broadcasts
                # over heads on assignment.
                freqs_i = torch.cat([
                    freqs_0.expand(seq_f, seq_h, seq_w, -1),
                    freqs_h[h_sam].view(1, seq_h, 1, -1).expand(
                        seq_f, seq_h, seq_w, -1),
                    freqs_w[w_sam].view(1, 1, seq_w, -1).expand(
                        seq_f, seq_h, seq_w, -1),
                ], dim=-1).reshape(seq_len, 1, -1)
            elif t_f < 0:
                if trainable_freqs is None:
                    raise ValueError(
                        "a negative target size selects trainable "
                        "frequencies, but `freqs` was not given as "
                        "[freqs, trainable_freqs]")
                freqs_i = trainable_freqs.unsqueeze(1)
            else:
                raise ValueError(
                    "target frame count must be non-zero for a non-empty "
                    "grid block")

            # apply rotary embedding
            output[i, offset:offset + seq_len] = freqs_i

        # Every sample of a grid entry occupies the same slot range, so the
        # offset advances once per entry (by the last sample's length).
        offset += seq_len
    return output