import torch
import torch.nn as nn


class MlpProjector(nn.Module):
    """Map input features to ``cfg.n_embed`` channels with a small GELU MLP.

    The configuration object must expose ``model_name``, ``input_dim`` and
    ``n_embed`` as attributes and must also support ``cfg.get("depth", 1)``.
    This looks like an OmegaConf/EasyDict-style config; TODO confirm against
    the callers.

    Only ``model_name == "MLP_GELU"`` is supported. It builds::

        Linear(input_dim, n_embed) [-> GELU -> Linear(n_embed, n_embed)] * (depth - 1)

    A ``depth`` of 1 or less yields a single ``Linear``.

    Raises:
        ValueError: if ``cfg.model_name`` is not ``"MLP_GELU"``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Guard clause: reject unsupported projector types before building anything.
        if cfg.model_name != "MLP_GELU":
            raise ValueError(f"Unknown projector type: {cfg.model_name}")

        depth = cfg.get("depth", 1)
        stack = [nn.Linear(cfg.input_dim, cfg.n_embed)]
        # Each extra level of depth appends an activation plus a square Linear.
        # Keeping them as flat Sequential entries preserves the state_dict key
        # layout (layers.0, layers.2, ...).
        for _ in range(depth - 1):
            stack.extend((nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)))
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projection to the last dimension of ``x`` (size ``input_dim``)."""
        return self.layers(x)