Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import random | |
| import torch | |
| from safetensors.torch import safe_open, save_file | |
| from tqdm import tqdm | |
def load_tensor(path: str, key: str = "embedding") -> torch.Tensor:
    """Read one named tensor from a safetensors file.

    Args:
        path: Location of the safetensors file to open.
        key: Name of the tensor stored inside the file.

    Returns:
        The tensor stored under ``key``, loaded through the PyTorch
        framework backend onto the CPU.
    """
    with safe_open(path, framework="pt", device="cpu") as reader:
        tensor = reader.get_tensor(key)
    return tensor
def compute_global_stats(file_list, key="embedding", length_weighted=True):
    """Compute the per-feature mean and standard deviation over many files.

    Each loaded tensor is flattened to ``[N, D]`` (all leading axes merged,
    last axis kept) and its rows are pooled with the rows of every other
    file, so every row carries the same weight.

    Args:
        file_list: Iterable of safetensors paths to read.
        key: Name of the tensor to load from each file.
        length_weighted: Accepted for interface compatibility only. It is
            not consulted; statistics are always pooled over rows.

    Returns:
        ``(mean, std)``: two ``[D]`` tensors. ``std`` is the square root of
        the population variance ``E[x^2] - E[x]^2``, floored at ``1e-8``.
        The result dtype is the dtype of the first loaded tensor when it is
        floating point, otherwise the default float dtype.

    Raises:
        ValueError: If ``file_list`` is empty or the files hold no rows.
    """
    # Materialize first so an empty input is reported clearly instead of
    # failing later on ``None / 0``.
    paths = list(file_list)
    if not paths:
        raise ValueError("file_list is empty; cannot compute statistics")

    total = None
    total_sq = None
    n_rows = 0
    out_dtype = None

    for path in tqdm(paths, desc="Computing stats"):
        tensor = load_tensor(path, key)  # presumably [B, T, D]; only the last axis matters
        if out_dtype is None:
            out_dtype = (
                tensor.dtype
                if tensor.is_floating_point()
                else torch.get_default_dtype()
            )

        flat = tensor.reshape(-1, tensor.shape[-1])  # [N, D]

        # Accumulate in float64: half-precision sums overflow quickly and
        # E[x^2] - E[x]^2 loses most of its digits in low precision.
        row_sum = flat.sum(dim=0, dtype=torch.float64)
        row_sq_sum = flat.to(torch.float64).square().sum(dim=0)

        if total is None:
            total, total_sq = row_sum, row_sq_sum
        else:
            total += row_sum
            total_sq += row_sq_sum
        n_rows += flat.shape[0]

    if n_rows == 0:
        raise ValueError("the given files contain no rows; cannot compute statistics")

    mean = total / n_rows
    var = total_sq / n_rows - mean**2
    # The floor keeps sqrt well-defined when rounding pushes var below zero.
    std = torch.sqrt(torch.clamp(var, min=1e-8))
    return mean.to(out_dtype), std.to(out_dtype)
def main(argv=None):
    """Command-line entry point: compute statistics and save them.

    Reads a text file with one safetensors path per line (blank lines are
    skipped), optionally draws a seeded random subset of the paths, computes
    the global mean/std with ``compute_global_stats`` and writes them to the
    output file under the keys ``"mean"`` and ``"std"``.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the process command line, as before.

    Exits with status 2 through ``parser.error`` when ``--max-files`` is
    negative or the file list contains no paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "filelist", type=str, help="Text file with list of safetensors paths"
    )
    parser.add_argument("output", type=str, help="Path to output stats.safetensors")
    parser.add_argument(
        "--key", type=str, default="audio_z", help="Key of tensor in safetensors file"
    )
    parser.add_argument(
        "--max-files", type=int, default=None, help="Max number of files to process"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling"
    )
    args = parser.parse_args(argv)

    # Reject a negative limit up front rather than letting random.sample
    # fail with a traceback.
    if args.max_files is not None and args.max_files < 0:
        parser.error("--max-files must not be negative")

    with open(args.filelist, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        files = [path for path in stripped if path]

    if not files:
        parser.error(f"no paths found in {args.filelist}")

    # As before, an omitted or zero --max-files means "use every file".
    if args.max_files:
        # A private generator gives the same sample as seeding the module
        # RNG, without disturbing global random state.
        rng = random.Random(args.seed)
        files = rng.sample(files, k=min(args.max_files, len(files)))

    mean, std = compute_global_stats(files, key=args.key)
    save_file({"mean": mean, "std": std}, args.output)
    # NOTE(review): the leading "β" looks like a mis-decoded symbol; confirm
    # the intended glyph before changing this message.
    print(f"β Saved to {args.output}")
    print("Example mean/std:", mean[:5], std[:5])


if __name__ == "__main__":
    main()