| """ | |
| Model Loading Utilities | |
| Provides comprehensive model loading capabilities for various formats and sources | |
| including PyTorch models, Safetensors, and Hugging Face transformers. | |
| """ | |
| import os | |
| import logging | |
| import asyncio | |
| from typing import Dict, Any, Optional, Union, List | |
| from pathlib import Path | |
| import json | |
| import requests | |
| from urllib.parse import urlparse | |
| import tempfile | |
| import shutil | |
| import torch | |
| import torch.nn as nn | |
| from transformers import ( | |
| AutoModel, AutoTokenizer, AutoConfig, AutoImageProcessor, | |
| AutoFeatureExtractor, AutoProcessor, AutoModelForCausalLM, | |
| AutoModelForSeq2SeqLM | |
| ) | |
| from safetensors import safe_open | |
| from safetensors.torch import load_file as load_safetensors | |
| import numpy as np | |
| from PIL import Image | |
| logger = logging.getLogger(__name__) | |
# Custom model configurations for special architectures
CUSTOM_MODEL_CONFIGS = {
    'ti2v': {
        'model_type': 'ti2v',
        'architecture': 'TI2VModel',
        'modalities': ['text', 'vision'],
        'supports_generation': True,
        'is_multimodal': True
    },
    'diffusion': {
        'model_type': 'diffusion',
        'architecture': 'DiffusionModel',
        'modalities': ['vision', 'text'],
        'supports_generation': True,
        'is_multimodal': True
    }
}
class ModelLoader:
    """
    Comprehensive model loader supporting multiple formats and sources
    """

    def __init__(self):
        self.supported_formats = {
            '.pt': 'pytorch',
            '.pth': 'pytorch',
            '.bin': 'pytorch',
            '.safetensors': 'safetensors',
            '.onnx': 'onnx',
            '.h5': 'keras',
            '.pkl': 'pickle',
            '.joblib': 'joblib'
        }
        self.modality_keywords = {
            'text': ['bert', 'gpt', 'roberta', 'electra', 'deberta', 'xlm', 'xlnet', 't5', 'bart'],
            'vision': ['vit', 'resnet', 'efficientnet', 'convnext', 'swin', 'deit', 'beit'],
            'multimodal': ['clip', 'blip', 'albef', 'flava', 'layoutlm', 'donut'],
            'audio': ['wav2vec', 'hubert', 'whisper', 'speech_t5']
        }
    async def load_model(self, source: str, **kwargs) -> Dict[str, Any]:
        """
        Load a model from various sources

        Args:
            source: Model source (file path, HF repo, URL)
            **kwargs: Additional loading parameters

        Returns:
            Dictionary containing model, tokenizer/processor, and metadata
        """
        try:
            logger.info(f"Loading model from: {source}")

            # Determine source type
            if self._is_url(source):
                return await self._load_from_url(source, **kwargs)
            elif self._is_huggingface_repo(source):
                return await self._load_from_huggingface(source, **kwargs)
            elif Path(source).exists():
                return await self._load_from_file(source, **kwargs)
            else:
                raise ValueError(f"Invalid model source: {source}")
        except Exception as e:
            logger.error(f"Error loading model from {source}: {str(e)}")
            raise
    async def get_model_info(self, source: str) -> Dict[str, Any]:
        """
        Get model information without loading the full model

        Args:
            source: Model source

        Returns:
            Model metadata and information
        """
        try:
            info = {
                'source': source,
                'format': 'unknown',
                'modality': 'unknown',
                'architecture': None,
                'parameters': None,
                'size_mb': None
            }

            if Path(source).exists():
                file_path = Path(source)
                info['size_mb'] = file_path.stat().st_size / (1024 * 1024)
                info['format'] = self.supported_formats.get(file_path.suffix, 'unknown')

                # Try to extract more info based on format
                if info['format'] == 'safetensors':
                    info.update(await self._get_safetensors_info(source))
                elif info['format'] == 'pytorch':
                    info.update(await self._get_pytorch_info(source))
            elif self._is_huggingface_repo(source):
                info.update(await self._get_huggingface_info(source))

            # Detect modality from model name/architecture
            info['modality'] = self._detect_modality(source, info.get('architecture', ''))

            return info
        except Exception as e:
            logger.warning(f"Error getting model info for {source}: {str(e)}")
            return {'source': source, 'error': str(e)}
    def _is_url(self, source: str) -> bool:
        """Check if source is a URL"""
        try:
            result = urlparse(source)
            return all([result.scheme, result.netloc])
        except Exception:
            return False

    def _is_huggingface_repo(self, source: str) -> bool:
        """Check if source is a Hugging Face repository"""
        # Simple heuristic: contains '/' but not a file extension
        return '/' in source and not any(source.endswith(ext) for ext in self.supported_formats.keys())

    def _detect_modality(self, source: str, architecture: str) -> str:
        """Detect model modality from source and architecture"""
        # architecture may be None when no architecture info is available
        text = (source + ' ' + (architecture or '')).lower()
        for modality, keywords in self.modality_keywords.items():
            if any(keyword in text for keyword in keywords):
                return modality
        return 'unknown'
    async def _load_from_file(self, file_path: str, **kwargs) -> Dict[str, Any]:
        """Load model from local file"""
        file_path = Path(file_path)
        format_type = self.supported_formats.get(file_path.suffix, 'unknown')

        if format_type == 'safetensors':
            return await self._load_safetensors(file_path, **kwargs)
        elif format_type == 'pytorch':
            return await self._load_pytorch(file_path, **kwargs)
        else:
            raise ValueError(f"Unsupported format: {format_type}")
    async def _load_from_url(self, url: str, **kwargs) -> Dict[str, Any]:
        """Load model from URL"""
        # Keep the URL's file extension so the format can be detected after download
        suffix = Path(urlparse(url).path).suffix

        # Download to temporary file
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
            response = requests.get(url, stream=True)
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
            tmp_path = tmp_file.name

        try:
            # Load from temporary file
            result = await self._load_from_file(tmp_path, **kwargs)
            result['source_url'] = url
            return result
        finally:
            # Cleanup temporary file
            os.unlink(tmp_path)
    async def _load_from_huggingface(self, repo_id: str, **kwargs) -> Dict[str, Any]:
        """Load model from Hugging Face repository"""
        try:
            # Get HF token from multiple sources
            hf_token = (
                kwargs.get('token') or
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            logger.info(f"Loading model {repo_id} with token: {'Yes' if hf_token else 'No'}")

            # Load configuration first with timeout
            trust_remote_code = kwargs.get('trust_remote_code', False)
            logger.info(f"Loading config for {repo_id} with trust_remote_code={trust_remote_code}")
            try:
                config = AutoConfig.from_pretrained(
                    repo_id,
                    trust_remote_code=trust_remote_code,
                    token=hf_token,
                    timeout=30  # 30 second timeout
                )
                logger.info(f"Successfully loaded config for {repo_id}")
            except Exception as e:
                logger.error(f"Failed to load config for {repo_id}: {e}")
                raise ValueError(f"Could not load model configuration: {str(e)}")

            # Load model with proper device handling
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Check if this is a large model and warn
            model_size_gb = self._estimate_model_size(config)
            if model_size_gb > 10:
                logger.warning(f"Large model detected ({model_size_gb:.1f}GB estimated). This may take several minutes to load.")

            # Check for custom architectures that need special handling
            model_type = getattr(config, 'model_type', None)

            # Try different loading strategies for different model types
            model = None
            loading_error = None

            # Special handling for ti2v and other custom architectures
            if model_type in CUSTOM_MODEL_CONFIGS:
                try:
                    logger.info(f"Loading custom architecture {model_type} for {repo_id}...")
                    model = await self._load_custom_architecture(repo_id, config, hf_token, trust_remote_code, **kwargs)
                except Exception as e:
                    logger.warning(f"Custom architecture loading failed: {e}")
                    loading_error = str(e)
            # Strategy 1: Try AutoModel (most common) if not already loaded
            if model is None:
                try:
                    logger.info(f"Attempting to load {repo_id} with AutoModel...")
                    model = AutoModel.from_pretrained(
                        repo_id,
                        config=config,
                        torch_dtype=kwargs.get('torch_dtype', torch.float32),
                        trust_remote_code=trust_remote_code,
                        token=hf_token,
                        low_cpu_mem_usage=True,
                        timeout=120  # 2 minute timeout for model loading
                    )
                    logger.info(f"Successfully loaded {repo_id} with AutoModel")
                except Exception as e:
                    loading_error = str(e)
                    logger.warning(f"AutoModel failed for {repo_id}: {e}")

            # Strategy 2: Try specific model classes for known types
            if model is None:
                model = await self._try_specific_model_classes(repo_id, config, hf_token, trust_remote_code, kwargs)

            # Strategy 3: Try with trust_remote_code if not already enabled
            if model is None and not trust_remote_code:
                try:
                    logger.info(f"Trying {repo_id} with trust_remote_code=True")
                    # For Gemma 3 models, try AutoModelForCausalLM specifically
                    if 'gemma-3' in repo_id.lower() or 'gemma3' in str(config).lower():
                        model = AutoModelForCausalLM.from_pretrained(
                            repo_id,
                            config=config,
                            torch_dtype=kwargs.get('torch_dtype', torch.float32),
                            trust_remote_code=True,
                            token=hf_token,
                            low_cpu_mem_usage=True
                        )
                    else:
                        model = AutoModel.from_pretrained(
                            repo_id,
                            config=config,
                            torch_dtype=kwargs.get('torch_dtype', torch.float32),
                            trust_remote_code=True,
                            token=hf_token,
                            low_cpu_mem_usage=True
                        )
                    logger.info(f"Successfully loaded {repo_id} with trust_remote_code=True")
                except Exception as e:
                    logger.warning(f"Loading with trust_remote_code=True failed: {e}")

            if model is None:
                raise ValueError(f"Could not load model {repo_id}. Last error: {loading_error}")
            # Move to device manually
            model = model.to(device)

            # Load appropriate processor/tokenizer
            processor = None
            try:
                # Try different processor types
                for processor_class in [AutoTokenizer, AutoImageProcessor, AutoFeatureExtractor, AutoProcessor]:
                    try:
                        processor = processor_class.from_pretrained(repo_id, token=hf_token)
                        break
                    except Exception:
                        continue
            except Exception as e:
                logger.warning(f"Could not load processor for {repo_id}: {e}")

            return {
                'model': model,
                'processor': processor,
                'config': config,
                'source': repo_id,
                'format': 'huggingface',
                'architecture': config.architectures[0] if hasattr(config, 'architectures') and config.architectures else None,
                'modality': self._detect_modality(repo_id, str(config.architectures) if hasattr(config, 'architectures') else ''),
                'parameters': sum(p.numel() for p in model.parameters()) if hasattr(model, 'parameters') else None
            }
        except Exception as e:
            logger.error(f"Error loading from Hugging Face repo {repo_id}: {str(e)}")
            raise
    async def _load_custom_architecture(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, **kwargs):
        """Load models with custom architectures like ti2v"""
        try:
            model_type = getattr(config, 'model_type', None)
            logger.info(f"Loading custom architecture: {model_type}")

            if model_type == 'ti2v':
                # For ti2v models, we need to create a wrapper that can work with our distillation
                return await self._load_ti2v_model(repo_id, config, hf_token, trust_remote_code, **kwargs)
            else:
                # For other custom architectures, try with trust_remote_code
                logger.info(f"Attempting to load custom model {repo_id} with trust_remote_code=True")

                # Try different model classes
                model_classes = [AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM]
                for model_class in model_classes:
                    try:
                        model = model_class.from_pretrained(
                            repo_id,
                            config=config,
                            trust_remote_code=True,  # Force trust_remote_code for custom architectures
                            token=hf_token,
                            low_cpu_mem_usage=True,
                            torch_dtype=torch.float32
                        )
                        logger.info(f"Successfully loaded {repo_id} with {model_class.__name__}")
                        return model
                    except Exception as e:
                        logger.warning(f"{model_class.__name__} failed for {repo_id}: {e}")
                        continue

                raise ValueError(f"All loading strategies failed for custom architecture {model_type}")
        except Exception as e:
            logger.error(f"Error loading custom architecture: {e}")
            raise
    async def _load_ti2v_model(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, **kwargs):
        """Special handling for ti2v (Text-to-Image/Video) models"""
        try:
            logger.info(f"Loading ti2v model: {repo_id}")

            # For ti2v models, we'll create a wrapper that extracts text features.
            # This allows us to use them in knowledge distillation.

            # Try to load with trust_remote_code=True (required for custom architectures)
            model = AutoModel.from_pretrained(
                repo_id,
                config=config,
                trust_remote_code=True,
                token=hf_token,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            )

            # Create a wrapper that can extract features for distillation
            class TI2VWrapper(torch.nn.Module):
                def __init__(self, base_model):
                    super().__init__()
                    self.base_model = base_model
                    self.config = base_model.config

                def forward(self, input_ids=None, attention_mask=None, **kwargs):
                    # Extract text encoder features if available
                    if hasattr(self.base_model, 'text_encoder'):
                        return self.base_model.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
                    elif hasattr(self.base_model, 'encoder'):
                        return self.base_model.encoder(input_ids=input_ids, attention_mask=attention_mask)
                    else:
                        # Fallback: try to get some meaningful representation
                        return self.base_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

            wrapped_model = TI2VWrapper(model)
            logger.info(f"Successfully wrapped ti2v model: {repo_id}")
            return wrapped_model
        except Exception as e:
            logger.error(f"Error loading ti2v model {repo_id}: {e}")
            raise
    async def _load_safetensors(self, file_path: Path, **kwargs) -> Dict[str, Any]:
        """Load model from Safetensors format"""
        try:
            # Load tensors
            tensors = {}
            with safe_open(file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)

            # Try to reconstruct model architecture
            model = self._reconstruct_model_from_tensors(tensors)

            return {
                'model': model,
                'tensors': tensors,
                'source': str(file_path),
                'format': 'safetensors',
                'parameters': sum(tensor.numel() for tensor in tensors.values()),
                'tensor_keys': list(tensors.keys())
            }
        except Exception as e:
            logger.error(f"Error loading Safetensors file {file_path}: {str(e)}")
            raise
    async def _load_pytorch(self, file_path: Path, **kwargs) -> Dict[str, Any]:
        """Load PyTorch model"""
        try:
            # Load checkpoint
            checkpoint = torch.load(file_path, map_location='cpu')

            # Extract model and metadata
            if isinstance(checkpoint, dict):
                model = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
                metadata = {k: v for k, v in checkpoint.items() if k not in ['model', 'state_dict']}
            else:
                model = checkpoint
                metadata = {}

            return {
                'model': model,
                'metadata': metadata,
                'source': str(file_path),
                'format': 'pytorch',
                # Count only tensor entries; checkpoints may mix in non-tensor metadata
                'parameters': sum(
                    t.numel() for t in model.values() if isinstance(t, torch.Tensor)
                ) if isinstance(model, dict) else None
            }
        except Exception as e:
            logger.error(f"Error loading PyTorch file {file_path}: {str(e)}")
            raise
    def _reconstruct_model_from_tensors(self, tensors: Dict[str, torch.Tensor]) -> nn.Module:
        """
        Attempt to reconstruct a PyTorch model from a tensor dictionary.

        This is a simplified implementation - in practice, this would need
        more sophisticated architecture detection.
        """
        class GenericModel(nn.Module):
            def __init__(self, tensors):
                super().__init__()
                self.tensors = nn.ParameterDict()
                for name, tensor in tensors.items():
                    self.tensors[name.replace('.', '_')] = nn.Parameter(tensor)

            def forward(self, x):
                # Placeholder forward pass
                return x

        return GenericModel(tensors)
    async def _get_safetensors_info(self, file_path: str) -> Dict[str, Any]:
        """Get information from Safetensors file"""
        try:
            info = {}
            with safe_open(file_path, framework="pt", device="cpu") as f:
                keys = list(f.keys())
                info['tensor_count'] = len(keys)
                info['tensor_keys'] = keys[:10]  # First 10 keys

                # Estimate parameters
                total_params = 0
                for key in keys:
                    tensor = f.get_tensor(key)
                    total_params += tensor.numel()
                info['parameters'] = total_params

            return info
        except Exception as e:
            logger.warning(f"Error getting Safetensors info: {e}")
            return {}
    async def _get_pytorch_info(self, file_path: str) -> Dict[str, Any]:
        """Get information from PyTorch file"""
        try:
            checkpoint = torch.load(file_path, map_location='cpu')
            info = {}

            if isinstance(checkpoint, dict):
                info['keys'] = list(checkpoint.keys())

                # Look for model/state_dict
                model_data = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
                if isinstance(model_data, dict):
                    # Count only tensor entries; checkpoints may mix in non-tensor metadata
                    info['parameters'] = sum(
                        t.numel() for t in model_data.values() if isinstance(t, torch.Tensor)
                    )
                    info['layer_count'] = len(model_data)

            return info
        except Exception as e:
            logger.warning(f"Error getting PyTorch info: {e}")
            return {}
    async def _get_huggingface_info(self, repo_id: str) -> Dict[str, Any]:
        """Get information from Hugging Face repository"""
        try:
            hf_token = (
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            config = AutoConfig.from_pretrained(repo_id, token=hf_token)

            info = {
                'architecture': config.architectures[0] if hasattr(config, 'architectures') and config.architectures else None,
                'model_type': getattr(config, 'model_type', None),
                'hidden_size': getattr(config, 'hidden_size', None),
                'num_layers': getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', None)),
                'vocab_size': getattr(config, 'vocab_size', None)
            }
            return info
        except Exception as e:
            logger.warning(f"Error getting Hugging Face info: {e}")
            return {}
    async def _try_specific_model_classes(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, kwargs: Dict[str, Any]):
        """Try loading with specific model classes for known architectures"""
        from transformers import (
            AutoModelForSequenceClassification,
            AutoModelForTokenClassification, AutoModelForQuestionAnswering,
            AutoModelForMaskedLM, AutoModelForImageClassification,
            AutoModelForObjectDetection, AutoModelForSemanticSegmentation,
            AutoModelForImageSegmentation, AutoModelForDepthEstimation,
            AutoModelForZeroShotImageClassification
        )

        # Map model types to appropriate AutoModel classes
        model_type = (getattr(config, 'model_type', '') or '').lower()
        architecture = getattr(config, 'architectures', [])
        arch_str = str(architecture).lower() if architecture else ''

        model_classes_to_try = []

        # Determine appropriate model classes based on model type and architecture
        if 'siglip' in model_type or 'siglip' in arch_str:
            # SigLIP models - try vision-related classes
            model_classes_to_try = [
                AutoModelForImageClassification,
                AutoModelForZeroShotImageClassification,
                AutoModel
            ]
        elif 'clip' in model_type or 'clip' in arch_str:
            model_classes_to_try = [AutoModelForZeroShotImageClassification, AutoModel]
        elif 'vit' in model_type or 'vision' in model_type:
            model_classes_to_try = [AutoModelForImageClassification, AutoModel]
        elif 'bert' in model_type or 'roberta' in model_type:
            model_classes_to_try = [AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel]
        elif 'gemma' in model_type or 'gemma' in arch_str:
            # Gemma models (including Gemma 3) - try causal LM classes
            model_classes_to_try = [AutoModelForCausalLM, AutoModel]
        elif 'gpt' in model_type or 'llama' in model_type:
            model_classes_to_try = [AutoModelForCausalLM, AutoModel]
        else:
            # Generic fallback
            model_classes_to_try = [
                AutoModelForCausalLM,  # Try causal LM first for newer models
                AutoModelForSequenceClassification,
                AutoModelForImageClassification,
                AutoModel
            ]

        # Try each model class
        for model_class in model_classes_to_try:
            try:
                logger.info(f"Trying {repo_id} with {model_class.__name__}")
                model = model_class.from_pretrained(
                    repo_id,
                    config=config,
                    torch_dtype=kwargs.get('torch_dtype', torch.float32),
                    trust_remote_code=trust_remote_code,
                    token=hf_token,
                    low_cpu_mem_usage=True
                )
                logger.info(f"Successfully loaded {repo_id} with {model_class.__name__}")
                return model
            except Exception as e:
                logger.debug(f"{model_class.__name__} failed for {repo_id}: {e}")
                continue

        return None
    async def load_trained_student(self, model_path: str) -> Dict[str, Any]:
        """Load a previously trained student model for retraining"""
        try:
            # Check if it's a Hugging Face model (starts with organization/)
            if '/' in model_path and not Path(model_path).exists():
                # This is likely a Hugging Face repository
                return await self._load_student_from_huggingface(model_path)

            # Local model path
            model_dir = Path(model_path)

            # Check if it's a trained student model
            config_path = model_dir / "config.json"
            if not config_path.exists():
                # Try alternative naming ("<weights>_config.json" next to the safetensors file)
                safetensors_files = list(model_dir.glob("*.safetensors"))
                if safetensors_files:
                    config_path = safetensors_files[0].with_name(safetensors_files[0].stem + '_config.json')

            if not config_path.exists():
                raise ValueError("No configuration file found for student model")

            # Load configuration
            with open(config_path, 'r') as f:
                config = json.load(f)

            # Verify it's a student model
            if not config.get('is_student_model', False):
                raise ValueError("This is not a trained student model")

            # Load training history
            history_path = model_dir / "training_history.json"
            if not history_path.exists():
                # Try alternative naming
                safetensors_files = list(model_dir.glob("*.safetensors"))
                if safetensors_files:
                    history_path = safetensors_files[0].with_name(safetensors_files[0].stem + '_training_history.json')

            training_history = {}
            if history_path.exists():
                with open(history_path, 'r') as f:
                    training_history = json.load(f)

            # Load model weights
            model_file = None
            for ext in ['.safetensors', '.bin', '.pt']:
                potential_file = model_dir / f"student_model{ext}"
                if potential_file.exists():
                    model_file = potential_file
                    break

            if not model_file:
                # Look for any model file
                for ext in ['.safetensors', '.bin', '.pt']:
                    files = list(model_dir.glob(f"*{ext}"))
                    if files:
                        model_file = files[0]
                        break

            if not model_file:
                raise ValueError("No model file found")

            return {
                'type': 'trained_student',
                'path': str(model_path),
                'config': config,
                'training_history': training_history,
                'model_file': str(model_file),
                'can_be_retrained': config.get('can_be_retrained', True),
                'original_teachers': training_history.get('retraining_info', {}).get('original_teachers', []),
                'recommended_lr': training_history.get('retraining_info', {}).get('recommended_learning_rate', 1e-5),
                'modalities': config.get('modalities', ['text']),
                'architecture': config.get('architecture', 'unknown')
            }
        except Exception as e:
            logger.error(f"Error loading trained student model: {e}")
            raise
    async def _load_student_from_huggingface(self, repo_id: str) -> Dict[str, Any]:
        """Load a student model from Hugging Face repository"""
        try:
            # Get HF token
            hf_token = (
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            logger.info(f"Loading student model from Hugging Face: {repo_id}")

            # Load configuration
            config = AutoConfig.from_pretrained(repo_id, token=hf_token)

            # Try to load the model to verify it exists and is accessible
            model = await self._load_from_huggingface(repo_id, token=hf_token)

            # Check if it's marked as a student model (optional)
            is_student = getattr(config, 'is_student_model', False)

            return {
                'type': 'huggingface_student',
                'path': repo_id,
                'config': config.__dict__ if hasattr(config, '__dict__') else {},
                'training_history': {},  # HF models may not have our training history
                'model_file': repo_id,  # For HF models, this is the repo ID
                'can_be_retrained': True,
                'original_teachers': [],  # Unknown for external models
                'recommended_lr': 1e-5,  # Default learning rate
                'modalities': ['text'],  # Default, could be enhanced
                'architecture': config.architectures[0] if getattr(config, 'architectures', None) else 'unknown',
                'is_huggingface': True
            }
        except Exception as e:
            logger.error(f"Error loading student model from Hugging Face: {e}")
            raise ValueError(f"Could not load student model from Hugging Face: {str(e)}")
    async def load_trained_student_from_space(self, space_name: str) -> Dict[str, Any]:
        """Load a student model from a Hugging Face Space"""
        try:
            # Get HF token
            hf_token = (
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            logger.info(f"Loading student model from Hugging Face Space: {space_name}")

            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)

            # List files in the Space to find model files
            try:
                files = api.list_repo_files(space_name, repo_type="space")

                # Look for model files in the models directory
                model_files = [f for f in files if f.startswith('models/') and f.endswith(('.safetensors', '.bin', '.pt'))]
                if not model_files:
                    # Look for model files in root
                    model_files = [f for f in files if f.endswith(('.safetensors', '.bin', '.pt'))]

                if not model_files:
                    raise ValueError(f"No model files found in Space {space_name}")

                # Use the first model file found
                model_file = model_files[0]
                logger.info(f"Found model file in Space: {model_file}")

                # For now, we'll treat Space models as external HF models.
                # In the future, we could download and cache them locally.
                return {
                    'type': 'space_student',
                    'path': space_name,
                    'config': {},  # Space models may not have our config format
                    'training_history': {},  # Unknown for Space models
                    'model_file': model_file,
                    'can_be_retrained': True,
                    'original_teachers': [],  # Unknown for external models
                    'recommended_lr': 1e-5,  # Default learning rate
                    'modalities': ['text'],  # Default, could be enhanced
                    'architecture': 'unknown',
                    'is_space': True,
                    'space_name': space_name,
                    'available_models': model_files
                }
            except Exception as e:
                logger.error(f"Error accessing Space files: {e}")
                # Fallback: treat as a regular HF model
                return await self._load_student_from_huggingface(space_name)
        except Exception as e:
            logger.error(f"Error loading student model from Space: {e}")
            raise ValueError(f"Could not load student model from Space: {str(e)}")
    def _estimate_model_size(self, config) -> float:
        """Estimate model size in GB based on configuration"""
        try:
            # Get basic parameters
            hidden_size = getattr(config, 'hidden_size', 768)
            num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', 12))
            vocab_size = getattr(config, 'vocab_size', 50000)

            # Rough estimation: parameters * 4 bytes (float32) / 1GB.
            # This is a very rough estimate.
            embedding_params = vocab_size * hidden_size
            layer_params = num_layers * (hidden_size * hidden_size * 4)  # Simplified
            total_params = embedding_params + layer_params

            # Convert to GB (4 bytes per parameter for float32)
            size_gb = (total_params * 4) / (1024 ** 3)
            return max(size_gb, 0.1)  # Minimum 0.1GB
        except Exception:
            return 1.0  # Default 1GB if estimation fails
    def validate_model_compatibility(self, models: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Validate that multiple models are compatible for knowledge distillation

        Args:
            models: List of loaded model dictionaries

        Returns:
            Validation result with compatibility information
        """
        if not models:
            return {'compatible': False, 'reason': 'No models provided'}

        if len(models) < 2:
            return {'compatible': False, 'reason': 'At least 2 models required for distillation'}

        # Check modality compatibility
        modalities = [model.get('modality', 'unknown') for model in models]
        unique_modalities = set(modalities)

        # Allow same modality or multimodal combinations
        if len(unique_modalities) == 1 and 'unknown' not in unique_modalities:
            compatibility_type = 'same_modality'
        elif 'multimodal' in unique_modalities or len(unique_modalities) > 1:
            compatibility_type = 'cross_modal'
        else:
            return {'compatible': False, 'reason': 'Unknown modalities detected'}

        return {
            'compatible': True,
            'type': compatibility_type,
            'modalities': list(unique_modalities),
            'model_count': len(models),
            'total_parameters': sum(model.get('parameters', 0) for model in models if model.get('parameters'))
        }
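

# Minimal usage sketch (illustrative only, not part of the loader above): shows how the
# async ModelLoader API might be driven from a script. The repo id
# "distilbert/distilbert-base-uncased" is just an example source; any local checkpoint
# path, URL, or Hugging Face repo would work the same way.
if __name__ == "__main__":
    async def _demo():
        loader = ModelLoader()

        # Inspect a model without fully loading it
        info = await loader.get_model_info("distilbert/distilbert-base-uncased")
        print(json.dumps(info, indent=2, default=str))

        # Load the model plus its tokenizer/processor and metadata
        result = await loader.load_model("distilbert/distilbert-base-uncased")
        print(f"Loaded {result['source']} ({result['modality']}), "
              f"parameters: {result['parameters']}")

    asyncio.run(_demo())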