File size: 680 Bytes
97bc03d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
import torch.nn as nn
class MlpProjector(nn.Module):
    """Stack of ``nn.Linear`` layers separated by ``nn.GELU`` activations.

    The stack is stored as an ``nn.Sequential`` under ``self.layers``. Its
    first layer maps ``cfg.input_dim`` features to ``cfg.n_embed``; every
    further layer maps ``cfg.n_embed`` to ``cfg.n_embed`` and is preceded by
    a GELU.

    Args:
        cfg: Configuration object. It must expose ``model_name``,
            ``input_dim`` and ``n_embed`` as attributes and provide a
            dict-style ``get(key, default)`` method, which is used to read
            the optional ``"depth"`` entry (defaults to 1).

    Raises:
        ValueError: If ``cfg.model_name`` is anything other than
            ``"MLP_GELU"``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Only one projector flavour is implemented; reject the rest early.
        if cfg.model_name != "MLP_GELU":
            raise ValueError(f"Unknown projector type: {cfg.model_name}")

        num_linear = cfg.get("depth", 1)

        # The input projection is always present, even when depth <= 1.
        stack = [nn.Linear(cfg.input_dim, cfg.n_embed)]
        # Each extra level of depth contributes one (GELU, Linear) pair.
        for _ in range(1, num_linear):
            stack.extend((nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
|