# models/nli_classifier.py
from transformers import pipeline
import torch


class NLIClassifier:
    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
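
    # Singleton: __new__ hands back the one shared instance so the heavyweight
    # pipelines load at most once per process, and __init__ checks _initialized
    # to skip re-running the loading logic on repeat construction.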
    def __init__(self):
        if self._initialized:
            return
        try:
            print("Loading NLI models (this may take a moment)...")
            device = 0 if torch.cuda.is_available() else -1
            self.models = []

            # Model 1: RoBERTa-large-MNLI (most accurate)
            try:
                self.models.append({
                    'name': 'roberta-large-mnli',
                    'pipeline': pipeline(
                        "text-classification",
                        model="roberta-large-mnli",
                        device=device
                    ),
                    'weight': 0.5
                })
                print("✓ Loaded RoBERTa-large-MNLI")
            except Exception as e:
                print(f"⚠ Failed to load RoBERTa-large-MNLI: {e}")

            # Model 2: DeBERTa-v3-large fine-tuned on MNLI/FEVER/ANLI (pre-trained checkpoint)
            try:
                self.models.append({
                    'name': 'deberta-v3-large-mnli',
                    'pipeline': pipeline(
                        "text-classification",
                        model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
                        device=device
                    ),
                    'weight': 0.5
                })
                print("✓ Loaded DeBERTa-v3-large-MNLI")
            except Exception as e:
                print(f"⚠ Failed to load DeBERTa-v3-large-MNLI: {e}")

            if not self.models:
                # Fallback: BART-large-MNLI if both primary models fail.
                # Loaded as a text-classification pipeline so it scores the
                # (evidence, claim) pair directly, the same way the primary
                # models do in classify().
                try:
                    self.models.append({
                        'name': 'bart-large-mnli',
                        'pipeline': pipeline(
                            "text-classification",
                            model="facebook/bart-large-mnli",
                            device=device
                        ),
                        'weight': 1.0
                    })
                    print("✓ Loaded BART-large-MNLI (fallback)")
                except Exception as e:
                    print(f"✗ Failed to load any NLI model: {e}")
                    raise RuntimeError("No NLI models loaded successfully")

            # Normalize weights
            total_weight = sum(m['weight'] for m in self.models)
            for model in self.models:
                model['weight'] /= total_weight
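            # e.g. if only one of the two 0.5-weight models loaded, total_weight
            # is 0.5 and that model's weight becomes 1.0; if both loaded, each
            # stays at 0.5.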

            self._initialized = True
            print(f"✓ Successfully loaded {len(self.models)} NLI model(s)")
        except Exception as e:
            print(f"Error loading NLI models: {e}")
            self.models = []
            self._initialized = False

    def classify(self, claim, evidence):
        """Classify the relationship between a claim and evidence using the model ensemble."""
        if not self.models:
            return {
                'label': 'NEUTRAL',
                'confidence': 0.5,
                'model_votes': {}
            }
        try:
            results = []
            model_votes = {}
            for model_info in self.models:
                try:
                    pipeline_obj = model_info['pipeline']
                    model_name = model_info['name']

                    # Standard NLI input: premise (the evidence) paired with the
                    # hypothesis (the claim). Passing a text/text_pair dict lets
                    # each tokenizer insert its own separator token; a literal
                    # "[SEP]" inside an f-string would be tokenized as plain
                    # text by RoBERTa- and DeBERTa-style tokenizers.
                    output = pipeline_obj(
                        {"text": evidence, "text_pair": claim}, top_k=1
                    )
                    # With top_k=1 recent transformers returns a one-element
                    # list; accept a bare dict too, to be safe across versions.
                    result = output[0] if isinstance(output, list) else output
                    label = result['label']
                    confidence = result['score']

                    # Map labels
                    label_mapping = {
                        'ENTAILMENT': 'ENTAILMENT',
                        'CONTRADICTION': 'CONTRADICTION',
                        'NEUTRAL': 'NEUTRAL',
                        'entailment': 'ENTAILMENT',
                        'contradiction': 'CONTRADICTION',
                        'neutral': 'NEUTRAL',
                        'LABEL_0': 'CONTRADICTION',
                        'LABEL_1': 'NEUTRAL',
                        'LABEL_2': 'ENTAILMENT'
                    }
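                    # The LABEL_n fallback assumes the common MNLI head order
                    # (0=contradiction, 1=neutral, 2=entailment) used by
                    # roberta-large-mnli and facebook/bart-large-mnli; models
                    # that ship named labels (e.g. the DeBERTa checkpoint)
                    # match the named keys above and never hit the fallback.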
                    mapped_label = label_mapping.get(label, 'NEUTRAL')
                    results.append({
                        'label': mapped_label,
                        'confidence': confidence,
                        'weight': model_info['weight']
                    })
                    model_votes[model_name] = mapped_label
                except Exception as e:
                    print(f"Error with model {model_info['name']}: {e}")
                    continue

            if not results:
                return {
                    'label': 'NEUTRAL',
                    'confidence': 0.5,
                    'model_votes': {}
                }

            # Weighted voting
            weighted_scores = {
                'ENTAILMENT': 0.0,
                'CONTRADICTION': 0.0,
                'NEUTRAL': 0.0
            }
            for result in results:
                weighted_scores[result['label']] += result['confidence'] * result['weight']

            # Get final label and confidence
            final_label = max(weighted_scores, key=weighted_scores.get)
            total_score = sum(weighted_scores.values())
            final_confidence = weighted_scores[final_label] / total_score if total_score > 0 else 0.5
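            # e.g. two 0.5-weight models voting ENTAILMENT@0.9 and NEUTRAL@0.6
            # give {ENTAILMENT: 0.45, NEUTRAL: 0.30, CONTRADICTION: 0.0}, so the
            # final label is ENTAILMENT with confidence 0.45 / 0.75 = 0.6.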

            return {
                'label': final_label,
                'confidence': final_confidence,
                'model_votes': model_votes,
                'weighted_scores': weighted_scores
            }
        except Exception as e:
            print(f"NLI classification error: {e}")
            return {
                'label': 'NEUTRAL',
                'confidence': 0.5,
                'model_votes': {}
            }
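

# Minimal usage sketch (illustrative only; assumes the checkpoints above are
# downloadable in the current environment, and the sample strings are made up):
if __name__ == "__main__":
    clf = NLIClassifier()  # repeated calls return the same singleton instance
    verdict = clf.classify(
        claim="The Eiffel Tower is in Berlin.",
        evidence="The Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
    )
    print(verdict['label'], round(verdict['confidence'], 3))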