Mehdi Lakbar
Initial demo of Lina-speech (pardi-speech)
56cfa73
import json
import math
import os
import sys
from contextlib import contextmanager, redirect_stdout
from dataclasses import dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import nn
from torchdyn.core import NeuralODE

from .modules import AdaLNFlowPredictor, AutoEncoder
@contextmanager
def suppress_stdout():
    """Discard everything written to ``sys.stdout`` inside the ``with`` block.

    Output is redirected to ``os.devnull``. On exit, including when the body
    raises, ``sys.stdout`` is restored to the stream it had on entry, and
    the devnull handle is closed.

    The devnull file is opened before ``sys.stdout`` is touched. A failure
    to open it therefore cannot close or replace the caller's real stdout.

    Only writes that go through ``sys.stdout`` are silenced. Output written
    straight to file descriptor 1 (e.g. by native code) is not affected.
    """
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        yield
def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Build a step -> LR-multiplier function: linear warmup, then cosine decay.

    The returned factor is relative to ``start_lr``. This makes it suitable
    as the ``lr_lambda`` of ``torch.optim.lr_scheduler.LambdaLR`` when the
    optimizer's base learning rate is ``start_lr``.

    * ``step < warmup_steps``: rises linearly from 0 towards 1.
    * ``warmup_steps <= step <= total_steps``: half-cosine from 1 down to
      ``end_lr / start_lr``.
    * ``step > total_steps``: held at ``end_lr / start_lr``. Progress is
      clamped, so the cosine does not start rising again once the schedule
      is exhausted.

    Args:
        warmup_steps: Number of linear warmup steps.
        total_steps: Step at which the decay reaches ``end_lr``.
        start_lr: Peak learning rate. It is used as a divisor, so it must be
            non-zero.
        end_lr: Final (floor) learning rate.

    Returns:
        A callable mapping an integer step to a float multiplier.
    """

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Past total_steps, cos(pi * progress) would increase again; clamp
        # so the schedule stays at end_lr.
        progress = min(progress, 1.0)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda
@dataclass
class PatchVAEConfig:
latent_dim: int
hidden_dim: int
latent_scaling: tuple[list[float], list[float]] | None
flow_factory: str
num_flow_layers: int
autoencoder_factory: str
num_autoencoder_layers: int
convnextformer_num_conv_per_transformer: int = 3
wavvae_path: str | None = None
fsq_levels: list[int] | None = None
bottleneck_size: int | None = None
latent_stride: int = 2
vae: bool = False
causal_transformer: bool = False
cond_dim: int | None = None
is_causal: bool = False
class PatchVAE(nn.Module):
    """Flow-matching model over patches of latent frames.

    A ``(B, T, latent_dim)`` latent sequence is optionally standardized.
    Groups of ``latent_stride`` consecutive frames are then folded into
    patches of width ``latent_dim * latent_stride``.

    ``self.autoencoder`` maps the patch sequence to a latent
    (``encode_patch``) and decodes that latent into a conditioning sequence.
    ``self.flow_net`` is a conditional velocity field. ``decode_patch``
    integrates it as an ODE from Gaussian noise, optionally with
    classifier-free guidance, to regenerate the patches.

    An optional WavVAE (``self.wavvae``) allows going from waveform to
    PatchVAE latents (``encode``) and back (``decode``).
    """

    def __init__(self, cfg: PatchVAEConfig):
        super().__init__()
        patch_dim = cfg.latent_dim * cfg.latent_stride
        # Conditional velocity network operating on patch vectors.
        self.flow_net = AdaLNFlowPredictor(
            feat_dim=patch_dim,
            dim=cfg.hidden_dim,
            n_layer=cfg.num_flow_layers,
            layer_factory=cfg.flow_factory,
            cond_dim=cfg.cond_dim,
            is_causal=cfg.is_causal,
        )
        # Encodes patch sequences and produces the flow conditioning.
        self.autoencoder = AutoEncoder(
            patch_dim,
            cfg.hidden_dim,
            cfg.num_autoencoder_layers,
            cfg.autoencoder_factory,
            out_dim=cfg.cond_dim,
            vae=cfg.vae,
            bottleneck_size=cfg.bottleneck_size,
            convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer,
            is_causal=cfg.is_causal,
        )
        if cfg.latent_scaling is not None:
            mean, std = cfg.latent_scaling
            # Buffers move with .to(device) and are saved in the state dict.
            self.register_buffer("mean_latent_scaling", torch.tensor(mean))
            self.register_buffer("std_latent_scaling", torch.tensor(std))
        else:
            self.mean_latent_scaling = None
            self.std_latent_scaling = None
        self.latent_stride = cfg.latent_stride
        self.latent_dim = cfg.latent_dim
        # Optional waveform codec, attached by from_pretrained / wavvae_from_pretrained.
        self.wavvae = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        """Load a model from a local directory or a Hugging Face Hub repo id.

        The directory must contain ``config.json`` (keys of
        ``PatchVAEConfig``) and ``model.st`` (a safetensors state dict). If
        the path does not exist locally, it is downloaded with
        ``huggingface_hub.snapshot_download``. When the config names a
        ``wavvae_path``, that WavVAE is loaded and attached as ``wavvae``.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            map_location: Device the model and its weights are placed on.

        Returns:
            The loaded ``PatchVAE``.
        """
        if Path(pretrained_model_name_or_path).exists():
            path = Path(pretrained_model_name_or_path)
        else:
            from huggingface_hub import snapshot_download

            path = Path(snapshot_download(pretrained_model_name_or_path))
        with open(path / "config.json", "r", encoding="utf-8") as f:
            config = PatchVAEConfig(**json.load(f))
        model = cls(config).to(map_location)
        state_dict = load_file(path / "model.st", device=map_location)
        model.load_state_dict(state_dict, assign=True)
        if config.wavvae_path is not None:
            from .. import WavVAE

            model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location)
        else:
            model.wavvae = None
        return model

    def wavvae_from_pretrained(
        self,
        pretrained_model_name_or_path: str,
        *args,
        **kwargs,
    ):
        """Attach a WavVAE loaded via ``WavVAE.from_pretrained`` (arguments forwarded)."""
        from .. import WavVAE

        self.wavvae = WavVAE.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            **kwargs,
        )

    def encode(self, wav: torch.Tensor):
        """Encode a waveform into PatchVAE latents through the attached WavVAE.

        Raises:
            RuntimeError: if no WavVAE is attached.
        """
        if self.wavvae is None:
            raise RuntimeError("please provide WavVAE model to encode from waveform")
        return self.encode_patch(self.wavvae.encode(wav))

    def decode(self, patchvae_latent: torch.Tensor, **kwargs):
        """Decode PatchVAE latents to a waveform through the attached WavVAE.

        ``kwargs`` are forwarded to ``decode_patch``.

        Raises:
            RuntimeError: if no WavVAE is attached.
        """
        if self.wavvae is None:
            raise RuntimeError("please provide WavVAE model to decode to waveform")
        z = self.decode_patch(patchvae_latent, **kwargs)
        return self.wavvae.decode(z)

    def normalize_z(self, z: torch.Tensor):
        """Standardize ``z`` with the stored mean/std, if latent scaling is configured."""
        if self.mean_latent_scaling is not None:
            z = (z - self.mean_latent_scaling) / self.std_latent_scaling
        return z

    def denormalize_z(self, z: torch.Tensor):
        """Undo ``normalize_z``, if latent scaling is configured."""
        if self.std_latent_scaling is not None:
            z = z * self.std_latent_scaling + self.mean_latent_scaling
        return z

    def _to_patches(self, z: torch.Tensor) -> torch.Tensor:
        """Fold ``latent_stride`` consecutive frames into one patch vector.

        ``(B, T, D) -> (B, T // stride, D * stride)``. Trailing frames that do
        not fill a whole patch are dropped. With a stride of 1, ``z`` is
        returned unchanged (its rank is still checked).
        """
        B, T, D = z.shape
        if self.latent_stride <= 1:
            return z
        kept = T - T % self.latent_stride
        return z[:, :kept].reshape(
            B, kept // self.latent_stride, D * self.latent_stride
        )

    def _from_patches(self, y: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_to_patches``: ``(B, T, D * stride) -> (B, T * stride, D)``."""
        B, T, D = y.shape
        return y.reshape(B, T * self.latent_stride, D // self.latent_stride)

    def encode_patch(self, z: torch.Tensor, deterministic: bool = False):
        """Normalize and patch a ``(B, T, latent_dim)`` latent, then encode it.

        Args:
            z: Latent frames. Must be rank 3.
            deterministic: Forwarded to ``self.autoencoder.encode``.
        """
        return self.autoencoder.encode(
            self._to_patches(self.normalize_z(z)), deterministic=deterministic
        )

    def _velocity_field(self, z_cond: torch.Tensor, cfg: float):
        """Return the ODE right-hand side used by ``decode_patch``.

        With ``cfg == 1.0``, the flow network is queried once with ``z_cond``.
        Otherwise the state is duplicated: the second copy gets an all-zero
        conditioning (the same "dropped" conditioning used in ``forward``).
        The two predictions are combined as ``uncond + cfg * (cond - uncond)``
        (classifier-free guidance).
        """
        # NOTE(review): torchdyn may inspect the callable's argument names;
        # keep the first parameter named ``t``.
        if cfg == 1.0:

            def solver_fn(t, Xt, *_args, **_kwargs):
                return self.flow_net(Xt, z_cond, t.unsqueeze(0))

            return solver_fn

        z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0)

        def solver_fn(t, Xt, *_args, **_kwargs):
            flow = self.flow_net(Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0))
            cond, uncond = flow.chunk(2, dim=0)
            return uncond + cfg * (cond - uncond)

        return solver_fn

    def decode_patch(
        self,
        latent: torch.Tensor,
        cfg: float = 2.0,
        num_steps: int = 15,
        solver: str = "euler",
        sensitivity: str = "adjoint",
        temperature: float = 1.0,
        **kwargs,
    ):
        """Sample latent frames from a PatchVAE latent by integrating the flow ODE.

        Args:
            latent: Input to ``self.autoencoder.decode``.
            cfg: Classifier-free guidance scale. ``1.0`` disables guidance.
            num_steps: Number of integration intervals on ``t`` in ``[0, 1]``.
            solver: torchdyn solver name.
            sensitivity: torchdyn sensitivity method.
            temperature: Scale applied to the initial Gaussian noise.
            **kwargs: Extra arguments forwarded to ``NeuralODE``.

        Returns:
            De-normalized frames of shape ``(B, T_patches * latent_stride,
            latent_dim)``.
        """
        with torch.no_grad():
            z_cond = self.autoencoder.decode(latent).transpose(1, 2)
            vector_field = self._velocity_field(z_cond, cfg)
            # Output from NeuralODE construction is silenced (presumably
            # informational prints).
            with suppress_stdout():
                node_ = NeuralODE(
                    vector_field,
                    solver=solver,
                    sensitivity=sensitivity,
                    **kwargs,
                )
            t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device)
            x0 = torch.randn(
                z_cond.shape[0],
                self.latent_dim * self.latent_stride,
                z_cond.shape[2],
                device=z_cond.device,
            )
            traj = node_.trajectory(x0 * temperature, t_span=t_span)
            # The trajectory's last state is the sample at t = 1.
            y_hat = traj[-1]
        return self.denormalize_z(self._from_patches(y_hat.transpose(1, 2)))

    def forward(
        self,
        z: torch.Tensor,
        t: torch.Tensor,
        drop_cond_rate: float = 0.0,
        drop_vae_rate: float = 0.0,
        sigma: float = 1e-4,
    ):
        """Compute the training losses.

        The patches are interpolated along the conditional flow-matching
        path ``x_t = (1 - (1 - sigma) t) x0 + t x1``. The velocity target is
        ``x1 - (1 - sigma) x0``, with ``x0`` Gaussian noise and ``x1`` the
        normalized patches.

        Args:
            z: Latent frames ``(B, T, latent_dim)``. Trailing frames that do
                not fill a whole patch are dropped, matching ``encode_patch``.
            t: One time value per batch element (reshaped to ``(-1, 1, 1)``;
                presumably in ``[0, 1]``).
            drop_cond_rate: Probability of zeroing a sample's conditioning
                (used for classifier-free guidance at decode time).
            drop_vae_rate: Forwarded to the autoencoder.
            sigma: Minimum noise level of the flow path.

        Returns:
            ``(flow_loss, ae_loss, prior)``, where ``prior`` is the
            conditioning after dropout.
        """
        z = self._to_patches(self.normalize_z(z))
        prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate)
        if drop_cond_rate > 0.0:
            to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate
            # Out-of-place so we never overwrite a tensor autograd may have saved.
            mask = to_drop.view(-1, *([1] * (prior.dim() - 1)))
            prior = prior.masked_fill(mask, 0.0)
        x0 = torch.randn_like(z)
        x1 = z
        flow_target = x1 - (1 - sigma) * x0
        alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
        xt = alpha * x0 + t.view(-1, 1, 1) * x1
        pred = self.flow_net(
            xt.transpose(1, 2),
            prior.transpose(1, 2),
            t,
        )
        flow_loss = nn.functional.mse_loss(flow_target.transpose(1, 2), pred)
        return flow_loss, ae_loss, prior