# STAR/star/models/adapter/projector.py
# Uploaded by MM-MVR ("Upload files"), commit 97bc03d (verified).
import torch
import torch.nn as nn
class MlpProjector(nn.Module):
    """Map input features to the embedding width with a small GELU MLP.

    The projector variant is chosen by ``cfg.model_name``. Only
    ``"MLP_GELU"`` is supported. It stacks ``cfg.get("depth", 1)`` linear
    layers with a GELU between consecutive ones:

    * the first layer maps ``cfg.input_dim`` -> ``cfg.n_embed``;
    * each later layer maps ``cfg.n_embed`` -> ``cfg.n_embed``.

    Args:
        cfg: configuration object read via attribute access (``model_name``,
            ``input_dim``, ``n_embed``) and a dict-style ``get`` (``depth``).
            NOTE(review): presumably an OmegaConf/EasyDict-style config —
            confirm against the caller.

    Raises:
        ValueError: if ``cfg.model_name`` is anything other than
            ``"MLP_GELU"``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Guard clause: reject unsupported projector types up front.
        if cfg.model_name != "MLP_GELU":
            raise ValueError(f"Unknown projector type: {cfg.model_name}")

        n_layers = cfg.get("depth", 1)

        # Creation order is significant. It fixes the Sequential indices
        # (hence state_dict keys such as ``layers.0.weight``) and the order
        # in which each Linear draws its initial weights from the RNG.
        stack = [nn.Linear(cfg.input_dim, cfg.n_embed)]
        for _ in range(1, n_layers):
            stack += [nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)]

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the projection stack to ``x`` and return the result."""
        return self.layers(x)