| import torch | |
| import torch.nn as nn | |
class MlpProjector(nn.Module):
    """Linear/GELU stack mapping ``cfg.input_dim`` features to ``cfg.n_embed``.

    ``cfg`` must expose ``model_name``, ``input_dim`` and ``n_embed`` as
    attributes and support ``cfg.get("depth", default)``. The only supported
    ``model_name`` is ``"MLP_GELU"``.

    With ``depth`` layers (default 1) the stack is::

        Linear(input_dim, n_embed) [GELU, Linear(n_embed, n_embed)] * (depth - 1)

    so the default depth produces a single ``Linear`` with no activation.

    Raises:
        ValueError: if ``cfg.model_name`` is not ``"MLP_GELU"``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg.model_name != "MLP_GELU":
            raise ValueError(f"Unknown projector type: {cfg.model_name}")

        depth = cfg.get("depth", 1)
        # The input projection is created first so parameter initialisation
        # draws from the RNG in the same order as a plain front-to-back build.
        stack = [nn.Linear(cfg.input_dim, cfg.n_embed)]
        for _ in range(depth - 1):
            stack.extend((nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the projection stack.

        ``x`` is expected to have a trailing dimension of ``cfg.input_dim``
        (what the first ``nn.Linear`` consumes). The output's trailing
        dimension is ``cfg.n_embed``.
        """
        projected = self.layers(x)
        return projected