# MPRALegNet / model_loader.py
# Commit a8c4b27 (verified) by Ni-os: "Update model_loader.py"
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
import json
from model import LegNet # Import your model
class CellTypeModelLoader:
    """Download and build cell-type-specific LegNet models from the Hugging Face Hub.

    The repository is read for three kinds of files:
      * ``config.json`` - shared LegNet hyper-parameters (used when the caller
        does not pass ``model_config``);
      * one JSON config per cell type (paths in ``available_cell_types``),
        from which the ``"weights_file"`` and ``"cell_type"`` keys are read;
      * the safetensors weight file named by ``"weights_file"``.
    """

    def __init__(self, repo_id="Ni-os/MPRALegNet"):
        """Create a loader bound to a Hugging Face Hub repository.

        Args:
            repo_id (str): Hub repository id that holds the configs and weights.
        """
        self.repo_id = repo_id
        # Lower-case cell-type name -> path of its config file inside the repo.
        self.available_cell_types = {
            "hepg2": "cell_type_configs/hepg2_config.json",
            "k562": "cell_type_configs/k562_config.json",
            "wtc11": "cell_type_configs/wtc11_config.json",
        }

    def get_available_cell_types(self):
        """Return the list of supported cell-type names."""
        return list(self.available_cell_types.keys())

    def get_device(self):
        """Return ``torch.device("cuda")`` if CUDA is available, else ``torch.device("cpu")``."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _download_json(self, filename):
        """Download ``filename`` from ``self.repo_id`` and return its parsed JSON content."""
        path = hf_hub_download(repo_id=self.repo_id, filename=filename)
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_model(self, cell_type, model_config=None, device=None):
        """
        Loads model for specified cell type

        Args:
            cell_type (str): one of ['hepg2', 'k562', 'wtc11'] (case-insensitive)
            model_config (dict): optional custom model parameters; must provide
                the keys in_ch, stem_ch, stem_ks, ef_ks, ef_block_sizes,
                pool_sizes and resize_factor. Defaults to the repo's config.json.
            device (torch.device or str): where to place the model; defaults
                to ``self.get_device()``.

        Returns:
            LegNet: the model on ``device`` with the cell type's weights loaded.

        Raises:
            ValueError: if ``cell_type`` is not one of the available cell types.
        """
        if device is None:
            device = self.get_device()

        # Normalise once so lookups below agree with the validation check.
        key = cell_type.lower()
        if key not in self.available_cell_types:
            available = self.get_available_cell_types()
            raise ValueError(f"Cell type '{cell_type}' not found. Available: {available}")

        # Shared architecture hyper-parameters (unless the caller overrides them).
        if model_config is None:
            model_config = self._download_json("config.json")

        model = LegNet(
            in_ch=model_config["in_ch"],
            stem_ch=model_config["stem_ch"],
            stem_ks=model_config["stem_ks"],
            ef_ks=model_config["ef_ks"],
            ef_block_sizes=model_config["ef_block_sizes"],
            pool_sizes=model_config["pool_sizes"],
            resize_factor=model_config["resize_factor"],
            activation=torch.nn.SiLU,
        ).to(device)

        # The per-cell-type config names the weight file to fetch.
        cell_config = self._download_json(self.available_cell_types[key])
        weights_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=cell_config["weights_file"],
        )

        # load_file returns CPU tensors by default; load_state_dict copies them
        # into the model's parameters, which already live on `device`.
        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)
        print(f"✅ Loaded model for {cell_config['cell_type']} cell type")
        return model
# Convenience wrapper: build a loader and load a model in a single call.
def load_cell_type_model(cell_type, repo_id="Ni-os/MPRALegNet", **kwargs):
    """
    Load the model for ``cell_type`` without managing a loader object.

    Any extra keyword arguments (e.g. ``model_config``, ``device``) are
    forwarded unchanged to ``CellTypeModelLoader.load_model``.

    Example:
        model = load_cell_type_model("hepg2")
        model = load_cell_type_model("k562", repo_id="Ni-os/MPRALegNet")
    """
    return CellTypeModelLoader(repo_id).load_model(cell_type, **kwargs)