# models/nli_classifier.py
from transformers import pipeline
import torch


class NLIClassifier:
    """Ensemble NLI classifier, implemented as a lazily initialized singleton."""

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        try:
            print("Loading NLI models (this may take a moment)...")
            device = 0 if torch.cuda.is_available() else -1
            self.models = []

            # Model 1: RoBERTa-large-MNLI (most accurate)
            try:
                self.models.append({
                    'name': 'roberta-large-mnli',
                    'pipeline': pipeline(
                        "text-classification",
                        model="roberta-large-mnli",
                        device=device
                    ),
                    'weight': 0.5
                })
                print("✓ Loaded RoBERTa-large-MNLI")
            except Exception as e:
                print(f"✗ Failed to load RoBERTa-large-MNLI: {e}")

            # Model 2: DeBERTa-v3-large fine-tuned on MNLI, FEVER, ANLI, LingNLI and WANLI
            try:
                self.models.append({
                    'name': 'deberta-v3-large-mnli',
                    'pipeline': pipeline(
                        "text-classification",
                        model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
                        device=device
                    ),
                    'weight': 0.5
                })
                print("✓ Loaded DeBERTa-v3-large-MNLI")
            except Exception as e:
                print(f"✗ Failed to load DeBERTa-v3-large-MNLI: {e}")

            if not self.models:
                # Fallback to BART if both fail. BART-large-MNLI is itself an
                # NLI classifier, so it is loaded as a plain text-classification
                # pipeline and shares the premise/hypothesis code path in
                # classify() with the other models (the original zero-shot
                # pipeline never saw the claim at all).
                try:
                    self.models.append({
                        'name': 'bart-large-mnli',
                        'pipeline': pipeline(
                            "text-classification",
                            model="facebook/bart-large-mnli",
                            device=device
                        ),
                        'weight': 1.0
                    })
                    print("✓ Loaded BART-large-MNLI (fallback)")
                except Exception as e:
                    print(f"✗ Failed to load any NLI model: {e}")
                    raise RuntimeError("No NLI models loaded successfully")

            # Normalize weights so they sum to 1
            total_weight = sum(m['weight'] for m in self.models)
            for model in self.models:
                model['weight'] /= total_weight

            self._initialized = True
            print(f"✓ Successfully loaded {len(self.models)} NLI model(s)")
        except Exception as e:
            print(f"Error loading NLI models: {e}")
            self.models = []
            self._initialized = False

    def classify(self, claim, evidence):
        """Classify the relationship between claim and evidence using the ensemble."""
        if not self.models:
            return {
                'label': 'NEUTRAL',
                'confidence': 0.5,
                'model_votes': {}
            }
        try:
            results = []
            model_votes = {}

            for model_info in self.models:
                try:
                    pipeline_obj = model_info['pipeline']
                    model_name = model_info['name']

                    # Standard NLI input: evidence as premise, claim as
                    # hypothesis. The dict form lets each tokenizer apply its
                    # own sentence-pair encoding rather than a hard-coded
                    # "[SEP]" string, which these tokenizers do not use.
                    output = pipeline_obj({"text": evidence, "text_pair": claim})
                    result = output[0] if isinstance(output, list) else output
                    label = result['label']
                    confidence = result['score']

                    # Map model-specific label names onto a common scheme
                    label_mapping = {
                        'ENTAILMENT': 'ENTAILMENT',
                        'CONTRADICTION': 'CONTRADICTION',
                        'NEUTRAL': 'NEUTRAL',
                        'entailment': 'ENTAILMENT',
                        'contradiction': 'CONTRADICTION',
                        'neutral': 'NEUTRAL',
                        'LABEL_0': 'CONTRADICTION',
                        'LABEL_1': 'NEUTRAL',
                        'LABEL_2': 'ENTAILMENT'
                    }
                    mapped_label = label_mapping.get(label, 'NEUTRAL')

                    results.append({
                        'label': mapped_label,
                        'confidence': confidence,
                        'weight': model_info['weight']
                    })
                    model_votes[model_name] = mapped_label
                except Exception as e:
                    print(f"Error with model {model_info['name']}: {e}")
                    continue

            if not results:
                return {
                    'label': 'NEUTRAL',
                    'confidence': 0.5,
                    'model_votes': {}
                }

            # Weighted voting: each model adds confidence * weight to its label
            weighted_scores = {
                'ENTAILMENT': 0.0,
                'CONTRADICTION': 0.0,
                'NEUTRAL': 0.0
            }
            for result in results:
                weighted_scores[result['label']] += result['confidence'] * result['weight']

            # Final label is the top-scoring class; its confidence is that
            # class's share of the total weighted score
            final_label = max(weighted_scores, key=weighted_scores.get)
            total_score = sum(weighted_scores.values())
            final_confidence = weighted_scores[final_label] / total_score if total_score > 0 else 0.5

            return {
                'label': final_label,
                'confidence': final_confidence,
                'model_votes': model_votes,
                'weighted_scores': weighted_scores
            }
        except Exception as e:
            print(f"NLI classification error: {e}")
            return {
                'label': 'NEUTRAL',
                'confidence': 0.5,
                'model_votes': {}
            }
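

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal check of the ensemble, assuming the models above download
# successfully on first run; the claim/evidence strings are made-up examples.
if __name__ == "__main__":
    clf = NLIClassifier()
    verdict = clf.classify(
        claim="The Eiffel Tower is located in Paris.",
        evidence=("The Eiffel Tower is a wrought-iron lattice tower "
                  "on the Champ de Mars in Paris, France."),
    )
    print(verdict['label'], round(verdict['confidence'], 3))
    print(verdict['model_votes'])

    # NLIClassifier is a singleton: constructing it again returns the same
    # already-initialized instance without reloading the models.
    assert NLIClassifier() is clf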