|
|
from huggingface_hub import hf_hub_download |
|
|
from safetensors.torch import load_file |
|
|
import torch |
|
|
import json |
|
|
from model import LegNet |
|
|
|
|
|
class CellTypeModelLoader:
    """Fetch and instantiate cell-type-specific LegNet models from the Hugging Face Hub.

    The loader knows a fixed mapping from cell-type name to a per-cell-type
    JSON config file stored in the Hub repository ``repo_id``. That config
    names the weights file to download for the cell type.
    """

    # Keys read from the model configuration and forwarded to ``LegNet`` as
    # keyword arguments of the same name.
    _MODEL_CONFIG_KEYS = (
        "in_ch",
        "stem_ch",
        "stem_ks",
        "ef_ks",
        "ef_block_sizes",
        "pool_sizes",
        "resize_factor",
    )

    def __init__(self, repo_id="Ni-os/MPRALegNet"):
        """Create a loader bound to one Hub repository.

        Args:
            repo_id (str): Hugging Face Hub repository to download from.
        """
        self.repo_id = repo_id
        # Cell-type name (lower case) -> path of its config inside the repo.
        self.available_cell_types = {
            "hepg2": "cell_type_configs/hepg2_config.json",
            "k562": "cell_type_configs/k562_config.json",
            "wtc11": "cell_type_configs/wtc11_config.json",
        }

    def get_available_cell_types(self):
        """Returns list of available cell types"""
        return list(self.available_cell_types)

    def get_device(self):
        """Return the CUDA device when CUDA is available, otherwise the CPU device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, cell_type, model_config=None, device=None):
        """Load the model for the specified cell type.

        Args:
            cell_type (str): one of ['hepg2', 'k562', 'wtc11'] (case-insensitive).
            model_config (dict): optional custom model parameters. When omitted,
                ``config.json`` is downloaded from the repository and used.
            device: optional target device passed to ``model.to``. When omitted,
                the result of ``get_device()`` is used.

        Returns:
            The ``LegNet`` instance, moved to ``device``, with the downloaded
            weights loaded into it.

        Raises:
            ValueError: if ``cell_type`` is not a known cell type.
            KeyError: if ``model_config`` lacks a required key, or the
                cell-type config lacks ``"weights_file"``.
        """
        # Normalise once. Anything that is not a string cannot be a known cell
        # type, so it is reported through the same ValueError as a bad name.
        key = cell_type.lower() if isinstance(cell_type, str) else None
        if key not in self.available_cell_types:
            available = self.get_available_cell_types()
            raise ValueError(f"Cell type '{cell_type}' not found. Available: {available}")

        if device is None:
            device = self.get_device()

        if model_config is None:
            config_path = hf_hub_download(
                repo_id=self.repo_id,
                filename="config.json",
            )
            with open(config_path, "r", encoding="utf-8") as f:
                model_config = json.load(f)

        # Fail early, before any weights are downloaded, and name every
        # missing key at once rather than only the first one hit.
        missing = [name for name in self._MODEL_CONFIG_KEYS if name not in model_config]
        if missing:
            raise KeyError(f"model_config is missing required keys: {missing}")

        model = LegNet(
            **{name: model_config[name] for name in self._MODEL_CONFIG_KEYS},
            activation=torch.nn.SiLU,
        ).to(device)

        cell_config_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=self.available_cell_types[key],
        )
        with open(cell_config_path, "r", encoding="utf-8") as f:
            cell_config = json.load(f)

        weights_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=cell_config["weights_file"],
        )

        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)

        # Fall back to the requested key so a config without "cell_type"
        # cannot fail the call after the model is already fully loaded.
        print(f"✅ Loaded model for {cell_config.get('cell_type', key)} cell type")
        return model
|
|
|
|
|
|
|
|
def load_cell_type_model(cell_type, repo_id="Ni-os/MPRALegNet", **kwargs):
    """Simple function to load a model by cell type.

    Builds a ``CellTypeModelLoader`` for ``repo_id`` and returns the result of
    its ``load_model`` call.

    Args:
        cell_type (str): cell type name, forwarded to ``load_model``.
        repo_id (str): Hugging Face Hub repository to download from.
        **kwargs: forwarded unchanged to ``CellTypeModelLoader.load_model``
            (it accepts ``model_config`` and ``device``).

    Returns:
        Whatever ``CellTypeModelLoader.load_model`` returns for ``cell_type``.

    Example:
        model = load_cell_type_model("hepg2")
        model = load_cell_type_model("k562", repo_id="Ni-os/MPRALegNet")
    """
    loader = CellTypeModelLoader(repo_id)
    return loader.load_model(cell_type, **kwargs)