# pardi-speech / codec / scripts / compute_stats.py
# Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# 56cfa73
import argparse
import random
import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm
def load_tensor(path: str, key: str = "embedding") -> torch.Tensor:
    """Return the tensor stored under *key* in the safetensors file *path*.

    The file is opened through ``safe_open`` with the PyTorch framework and
    the tensor is materialised on the CPU. The file handle is closed before
    returning. Behaviour for a missing key is whatever ``get_tensor`` does
    (presumably it raises -- confirm against the safetensors version in use).
    """
    with safe_open(path, framework="pt", device="cpu") as reader:
        result = reader.get_tensor(key)
    return result
def compute_global_stats(file_list, key="embedding", length_weighted=True):
    """Compute per-channel mean and standard deviation pooled over many files.

    Each path in *file_list* is loaded with ``load_tensor``. All leading
    dimensions are flattened, so every row along the last dimension counts as
    one sample. Statistics are pooled over every row of every file, which means
    files with more rows carry proportionally more weight.

    Args:
        file_list: Iterable of paths to safetensors files.
        key: Name of the tensor to read from each file.
        length_weighted: Ignored. Pooling is always weighted by row count.
            The parameter is kept so existing callers keep working.

    Returns:
        ``(mean, std)``: two 1-D tensors whose length is the size of the last
        dimension. Their dtype is the first tensor's dtype when that is
        floating point, and the default float dtype otherwise (this matches
        the previous results for float32 inputs). ``std`` is the population
        standard deviation, floored at ``sqrt(1e-8) = 1e-4``.

    Raises:
        ValueError: if *file_list* is empty, the tensors contain no rows, or
            the files disagree on the size of the last dimension.
    """
    sum_all = None
    sum_sq_all = None
    count_all = 0
    out_dtype = None
    for path in tqdm(file_list, desc="Computing stats"):
        tensor = load_tensor(path, key)  # any rank >= 1; last dim = channels
        # Flatten to [N, D] and accumulate in float64. The E[x^2] - E[x]^2
        # formula below loses most of its precision in float32/float16 once
        # the sums span many rows.
        flat = tensor.reshape(-1, tensor.shape[-1]).to(torch.float64)
        if sum_all is None:
            out_dtype = (
                tensor.dtype
                if tensor.dtype.is_floating_point
                else torch.get_default_dtype()
            )
            sum_all = flat.sum(dim=0)
            sum_sq_all = flat.square().sum(dim=0)
        else:
            if flat.shape[1] != sum_all.shape[0]:
                # Without this check a size-1 last dim would broadcast silently.
                raise ValueError(
                    f"{path}: last dimension is {flat.shape[1]}, "
                    f"expected {sum_all.shape[0]}"
                )
            sum_all += flat.sum(dim=0)
            sum_sq_all += flat.square().sum(dim=0)
        count_all += flat.shape[0]

    if sum_all is None:
        raise ValueError("compute_global_stats: file_list is empty")
    if count_all == 0:
        raise ValueError("compute_global_stats: input tensors contain no rows")

    mean = sum_all / count_all
    # Population variance. The clamp absorbs tiny negative values caused by
    # rounding and puts a floor of 1e-4 under std.
    var = sum_sq_all / count_all - mean.square()
    std = torch.sqrt(torch.clamp(var, min=1e-8))
    return mean.to(out_dtype), std.to(out_dtype)
def main():
    """Command-line entry point: compute and save global mean/std statistics.

    Reads a text file with one safetensors path per line (blank lines and
    surrounding whitespace are ignored). When ``--max-files`` is given, it
    draws a reproducible random subset. It then computes per-channel
    statistics with ``compute_global_stats`` and writes
    ``{"mean": ..., "std": ...}`` to the output path with ``save_file``.
    Invalid arguments or an empty filelist exit via ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description="Compute per-channel mean/std over a list of safetensors files."
    )
    parser.add_argument(
        "filelist", type=str, help="Text file with list of safetensors paths"
    )
    parser.add_argument("output", type=str, help="Path to output stats.safetensors")
    parser.add_argument(
        "--key", type=str, default="audio_z", help="Key of tensor in safetensors file"
    )
    parser.add_argument(
        "--max-files", type=int, default=None, help="Max number of files to process"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling"
    )
    args = parser.parse_args()
    # Previously 0 meant "no limit" (falsy) and negatives crashed in sample().
    if args.max_files is not None and args.max_files < 1:
        parser.error("--max-files must be a positive integer")

    with open(args.filelist, encoding="utf-8") as f:
        files = [line.strip() for line in f if line.strip()]
    if not files:
        parser.error(f"no paths found in {args.filelist}")

    if args.max_files is not None:
        # A private RNG seeded identically yields the same sample as
        # random.seed(seed) + random.sample(), without mutating global state.
        rng = random.Random(args.seed)
        files = rng.sample(files, k=min(args.max_files, len(files)))

    mean, std = compute_global_stats(files, key=args.key)
    save_file({"mean": mean, "std": std}, args.output)
    print(f"✅ Saved to {args.output}")
    print("Example mean/std:", mean[:5], std[:5])


if __name__ == "__main__":
    main()