Spaces:
Sleeping
Sleeping
Update predictor.py
Browse files- predictor.py +111 -21
predictor.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import json
|
| 2 |
import re
|
|
|
|
|
|
|
| 3 |
import onnxruntime as ort
|
| 4 |
import numpy as np
|
| 5 |
from typing import List, Dict, Set, Optional
|
|
@@ -22,6 +24,17 @@ class SentenceExtractor:
|
|
| 22 |
word_score_plus_threshold: int = 1,
|
| 23 |
word_score_minus_threshold: int = -1,
|
| 24 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
|
| 26 |
self.all_keywords = self._extract_all_keywords()
|
| 27 |
|
|
@@ -36,22 +49,51 @@ class SentenceExtractor:
|
|
| 36 |
self.aggregation_mode = "max"
|
| 37 |
self.word_score_plus_threshold = int(word_score_plus_threshold)
|
| 38 |
self.word_score_minus_threshold = int(word_score_minus_threshold)
|
|
|
|
|
|
|
|
|
|
| 39 |
try:
|
| 40 |
-
|
|
|
|
| 41 |
self.input_name = self.ort_session.get_inputs()[0].name
|
| 42 |
self.output_name = self.ort_session.get_outputs()[0].name
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
print("ONNX 模型加载成功")
|
|
|
|
| 44 |
except Exception as e:
|
| 45 |
print(f"ONNX 模型加载失败: {e}")
|
| 46 |
self.ort_session = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def _preprocess_text(self, text: str) -> np.ndarray:
|
| 49 |
try:
|
| 50 |
from transformers import AutoTokenizer
|
|
|
|
| 51 |
try:
|
| 52 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 53 |
except Exception:
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
inputs = tokenizer(
|
| 56 |
text,
|
| 57 |
truncation=True,
|
|
@@ -59,24 +101,25 @@ class SentenceExtractor:
|
|
| 59 |
max_length=512,
|
| 60 |
return_tensors='np'
|
| 61 |
)
|
|
|
|
|
|
|
| 62 |
return inputs
|
| 63 |
except Exception as e:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def _predict_grade_with_model(self, text: str) -> str:
|
| 72 |
try:
|
| 73 |
if not self.ort_session:
|
| 74 |
word_score = self._calculate_word_scores(text)["total_score"]
|
|
|
|
| 75 |
if word_score > 1:
|
| 76 |
-
|
| 77 |
if word_score < -1:
|
| 78 |
-
|
| 79 |
-
return "
|
| 80 |
|
| 81 |
inputs = self._preprocess_text(text)
|
| 82 |
|
|
@@ -106,15 +149,36 @@ class SentenceExtractor:
|
|
| 106 |
predictions = outputs[0]
|
| 107 |
grade_index = int(np.argmax(predictions))
|
| 108 |
grades = ['A', 'B', 'C', 'D', 'E']
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
except Exception as e:
|
| 111 |
print(f"模型预测出错: {e}")
|
| 112 |
word_score = self._calculate_word_scores(text)["total_score"]
|
|
|
|
| 113 |
if word_score > 1:
|
| 114 |
-
|
| 115 |
if word_score < -1:
|
| 116 |
-
|
| 117 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
|
| 120 |
try:
|
|
@@ -315,9 +379,17 @@ class SentenceExtractor:
|
|
| 315 |
scored_sentences = []
|
| 316 |
total_sentence_score = 0
|
| 317 |
for sentence in relevant_sentences:
|
| 318 |
-
|
|
|
|
| 319 |
score = score_map.get(grade, 3)
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
total_sentence_score += score
|
| 322 |
|
| 323 |
comprehensive_grade = "C"
|
|
@@ -346,4 +418,22 @@ class SentenceExtractor:
|
|
| 346 |
"neutral_word_count": word_scores["neutral_count"],
|
| 347 |
"scored_sentences": scored_sentences,
|
| 348 |
"count": len(relevant_sentences),
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import re
|
| 3 |
+
import os
|
| 4 |
+
import hashlib
|
| 5 |
import onnxruntime as ort
|
| 6 |
import numpy as np
|
| 7 |
from typing import List, Dict, Set, Optional
|
|
|
|
| 24 |
word_score_plus_threshold: int = 1,
|
| 25 |
word_score_minus_threshold: int = -1,
|
| 26 |
):
|
| 27 |
+
# 统一以文件所在目录为根,避免工作目录不同导致找不到资源
|
| 28 |
+
self.base_dir = os.path.dirname(os.path.abspath(__file__))
|
| 29 |
+
self.tokenizer_dir = self.base_dir
|
| 30 |
+
|
| 31 |
+
# 允许传相对路径:自动转绝对
|
| 32 |
+
if not os.path.isabs(model_path):
|
| 33 |
+
model_path = os.path.join(self.base_dir, model_path)
|
| 34 |
+
|
| 35 |
+
if not os.path.isabs(eval_keywords_path):
|
| 36 |
+
eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)
|
| 37 |
+
|
| 38 |
self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
|
| 39 |
self.all_keywords = self._extract_all_keywords()
|
| 40 |
|
|
|
|
| 49 |
self.aggregation_mode = "max"
|
| 50 |
self.word_score_plus_threshold = int(word_score_plus_threshold)
|
| 51 |
self.word_score_minus_threshold = int(word_score_minus_threshold)
|
| 52 |
+
self.providers: Optional[List[str]] = None
|
| 53 |
+
self.tokenizer_loaded: bool = False
|
| 54 |
+
self.last_tokenizer_error: Optional[str] = None
|
| 55 |
try:
|
| 56 |
+
# 强制使用 CPU provider,避免某些环境下选择到不可用的 GPU provider 导致加载失败
|
| 57 |
+
self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
| 58 |
self.input_name = self.ort_session.get_inputs()[0].name
|
| 59 |
self.output_name = self.ort_session.get_outputs()[0].name
|
| 60 |
+
try:
|
| 61 |
+
self.providers = self.ort_session.get_providers()
|
| 62 |
+
except Exception:
|
| 63 |
+
self.providers = None
|
| 64 |
print("ONNX 模型加载成功")
|
| 65 |
+
self.model_loaded: bool = True
|
| 66 |
except Exception as e:
|
| 67 |
print(f"ONNX 模型加载失败: {e}")
|
| 68 |
self.ort_session = None
|
| 69 |
+
self.model_loaded = False
|
| 70 |
+
|
| 71 |
+
# 记录模型文件信息,便于排查“用错模型”问题
|
| 72 |
+
try:
|
| 73 |
+
self.model_path_abs: Optional[str] = os.path.abspath(model_path)
|
| 74 |
+
self.model_sha256: Optional[str] = None
|
| 75 |
+
if os.path.exists(model_path):
|
| 76 |
+
sha = hashlib.sha256()
|
| 77 |
+
with open(model_path, 'rb') as f:
|
| 78 |
+
for chunk in iter(lambda: f.read(8192), b''):
|
| 79 |
+
sha.update(chunk)
|
| 80 |
+
self.model_sha256 = sha.hexdigest()
|
| 81 |
+
except Exception:
|
| 82 |
+
self.model_path_abs = None
|
| 83 |
+
self.model_sha256 = None
|
| 84 |
|
| 85 |
def _preprocess_text(self, text: str) -> np.ndarray:
|
| 86 |
try:
|
| 87 |
from transformers import AutoTokenizer
|
| 88 |
+
# 1) 优先从与脚本同目录加载本地 tokenizer(部署一起带上 tokenizer.json 等文件)
|
| 89 |
try:
|
| 90 |
+
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
|
| 91 |
except Exception:
|
| 92 |
+
try:
|
| 93 |
+
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
|
| 94 |
+
except Exception:
|
| 95 |
+
# 2) 兜底:在线模型(需要外网)
|
| 96 |
+
tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
|
| 97 |
inputs = tokenizer(
|
| 98 |
text,
|
| 99 |
truncation=True,
|
|
|
|
| 101 |
max_length=512,
|
| 102 |
return_tensors='np'
|
| 103 |
)
|
| 104 |
+
self.tokenizer_loaded = True
|
| 105 |
+
self.last_tokenizer_error = None
|
| 106 |
return inputs
|
| 107 |
except Exception as e:
|
| 108 |
+
self.tokenizer_loaded = False
|
| 109 |
+
self.last_tokenizer_error = str(e)
|
| 110 |
+
# 继续抛出异常,由上层捕获并回退,同时记录原因
|
| 111 |
+
raise
|
| 112 |
+
|
| 113 |
+
def _predict_grade_with_model(self, text: str) -> Dict[str, any]:
|
|
|
|
|
|
|
| 114 |
try:
|
| 115 |
if not self.ort_session:
|
| 116 |
word_score = self._calculate_word_scores(text)["total_score"]
|
| 117 |
+
grade = "C"
|
| 118 |
if word_score > 1:
|
| 119 |
+
grade = "B"
|
| 120 |
if word_score < -1:
|
| 121 |
+
grade = "D"
|
| 122 |
+
return {"grade": grade, "source": "rule", "word_score_total": word_score}
|
| 123 |
|
| 124 |
inputs = self._preprocess_text(text)
|
| 125 |
|
|
|
|
| 149 |
predictions = outputs[0]
|
| 150 |
grade_index = int(np.argmax(predictions))
|
| 151 |
grades = ['A', 'B', 'C', 'D', 'E']
|
| 152 |
+
probs = self._softmax(predictions)[0].tolist()
|
| 153 |
+
return {
|
| 154 |
+
"grade": grades[grade_index],
|
| 155 |
+
"source": "model",
|
| 156 |
+
"prob": float(probs[grade_index]),
|
| 157 |
+
"probs": probs,
|
| 158 |
+
"logits": predictions[0].tolist(),
|
| 159 |
+
}
|
| 160 |
except Exception as e:
|
| 161 |
print(f"模型预测出错: {e}")
|
| 162 |
word_score = self._calculate_word_scores(text)["total_score"]
|
| 163 |
+
grade = "C"
|
| 164 |
if word_score > 1:
|
| 165 |
+
grade = "B"
|
| 166 |
if word_score < -1:
|
| 167 |
+
grade = "D"
|
| 168 |
+
return {
|
| 169 |
+
"grade": grade,
|
| 170 |
+
"source": "rule",
|
| 171 |
+
"word_score_total": word_score,
|
| 172 |
+
"reason": str(e),
|
| 173 |
+
"tokenizer_loaded": self.tokenizer_loaded,
|
| 174 |
+
"last_tokenizer_error": self.last_tokenizer_error,
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
@staticmethod
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that
    large logits do not overflow ``np.exp``; the shift cancels out in the
    final ratio and therefore does not change the result.

    Args:
        x: Array (or array-like, e.g. a list) of logits.
        axis: Axis along which to normalize. Defaults to the last axis,
            which is what this helper previously hard-coded.

    Returns:
        An array with the same shape as ``x`` whose entries along ``axis``
        are non-negative and sum to 1.
    """
    # np.asarray is a no-op for ndarrays and lets plain sequences through.
    logits = np.asarray(x)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)
|
| 182 |
|
| 183 |
def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
|
| 184 |
try:
|
|
|
|
| 379 |
scored_sentences = []
|
| 380 |
total_sentence_score = 0
|
| 381 |
for sentence in relevant_sentences:
|
| 382 |
+
info = self._predict_grade_with_model(sentence)
|
| 383 |
+
grade = info.get("grade", "C")
|
| 384 |
score = score_map.get(grade, 3)
|
| 385 |
+
# 附带调试信息
|
| 386 |
+
scored_sentences.append({
|
| 387 |
+
"sentence": sentence,
|
| 388 |
+
"grade": grade,
|
| 389 |
+
"source": info.get("source", "unknown"),
|
| 390 |
+
"prob": info.get("prob"),
|
| 391 |
+
"word_score_total": info.get("word_score_total"),
|
| 392 |
+
})
|
| 393 |
total_sentence_score += score
|
| 394 |
|
| 395 |
comprehensive_grade = "C"
|
|
|
|
| 418 |
"neutral_word_count": word_scores["neutral_count"],
|
| 419 |
"scored_sentences": scored_sentences,
|
| 420 |
"count": len(relevant_sentences),
|
| 421 |
+
# 调试字段
|
| 422 |
+
"debug": {
|
| 423 |
+
"model_loaded": getattr(self, "model_loaded", False),
|
| 424 |
+
"model_path_abs": getattr(self, "model_path_abs", None),
|
| 425 |
+
"model_sha256": getattr(self, "model_sha256", None),
|
| 426 |
+
"providers": self.providers,
|
| 427 |
+
"tokenizer_loaded": self.tokenizer_loaded,
|
| 428 |
+
"last_tokenizer_error": self.last_tokenizer_error,
|
| 429 |
+
"aggregation_mode": self.aggregation_mode,
|
| 430 |
+
"min_sentence_char_len": self.min_sentence_char_len,
|
| 431 |
+
"merge_leading_punct": self.merge_leading_punct,
|
| 432 |
+
"word_score_plus_threshold": self.word_score_plus_threshold,
|
| 433 |
+
"word_score_minus_threshold": self.word_score_minus_threshold,
|
| 434 |
+
"relevant_sentences": relevant_sentences,
|
| 435 |
+
"word_score_total": word_scores["total_score"],
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
|