Update predictor.py
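Simplify predictor.py: drop the os/hashlib-based path resolution and SHA-256 model fingerprinting, the model_loaded flag and debug payload, and the now-unused _softmax helper. _predict_grade_with_model returns a plain grade string ('A'..'E') instead of a metadata dict, and scored_sentences entries keep only the sentence and its grade.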
predictor.py CHANGED (+14 -87)
@@ -1,7 +1,5 @@
 import json
 import re
-import os
-import hashlib
 import onnxruntime as ort
 import numpy as np
 from typing import List, Dict, Set, Optional
@@ -24,17 +22,6 @@ class SentenceExtractor:
         word_score_plus_threshold: int = 1,
         word_score_minus_threshold: int = -1,
     ):
-        # Root all paths at this file's directory so resources resolve regardless of the working directory
-        self.base_dir = os.path.dirname(os.path.abspath(__file__))
-        self.tokenizer_dir = self.base_dir
-
-        # Relative paths are allowed: convert them to absolute
-        if not os.path.isabs(model_path):
-            model_path = os.path.join(self.base_dir, model_path)
-
-        if not os.path.isabs(eval_keywords_path):
-            eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)
-
         self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
         self.all_keywords = self._extract_all_keywords()
 
@@ -50,43 +37,21 @@ class SentenceExtractor:
         self.word_score_plus_threshold = int(word_score_plus_threshold)
         self.word_score_minus_threshold = int(word_score_minus_threshold)
         try:
-
-            self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
+            self.ort_session = ort.InferenceSession(model_path)
             self.input_name = self.ort_session.get_inputs()[0].name
             self.output_name = self.ort_session.get_outputs()[0].name
             print("ONNX model loaded successfully")
-            self.model_loaded: bool = True
         except Exception as e:
             print(f"Failed to load ONNX model: {e}")
             self.ort_session = None
-            self.model_loaded = False
-
-        # Record the model file's identity to make "wrong model" issues easy to diagnose
-        try:
-            self.model_path_abs: Optional[str] = os.path.abspath(model_path)
-            self.model_sha256: Optional[str] = None
-            if os.path.exists(model_path):
-                sha = hashlib.sha256()
-                with open(model_path, 'rb') as f:
-                    for chunk in iter(lambda: f.read(8192), b''):
-                        sha.update(chunk)
-                self.model_sha256 = sha.hexdigest()
-        except Exception:
-            self.model_path_abs = None
-            self.model_sha256 = None
 
     def _preprocess_text(self, text: str) -> np.ndarray:
         try:
             from transformers import AutoTokenizer
-            # 1) Prefer the local tokenizer next to this script (deploy tokenizer.json etc. alongside it)
             try:
-                tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
+                tokenizer = AutoTokenizer.from_pretrained(".", local_files_only=True)
             except Exception:
-
-                tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
-            except Exception:
-                # 2) Last resort: the online model (requires network access)
-                tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
+                tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
             inputs = tokenizer(
                 text,
                 truncation=True,
@@ -103,16 +68,15 @@ class SentenceExtractor:
             features[0, i] = (ord(ch) % 256) / 255.0
         return features
 
-    def _predict_grade_with_model(self, text: str) -> Dict:
+    def _predict_grade_with_model(self, text: str) -> str:
         try:
             if not self.ort_session:
                 word_score = self._calculate_word_scores(text)["total_score"]
-                grade = "C"
                 if word_score > 1:
-                    grade = "B"
+                    return "B"
                 if word_score < -1:
-                    grade = "D"
+                    return "D"
-                return {"grade": grade, "source": "word_score", "word_score_total": word_score}
+                return "C"
 
             inputs = self._preprocess_text(text)
 
@@ -142,29 +106,15 @@ class SentenceExtractor:
             predictions = outputs[0]
             grade_index = int(np.argmax(predictions))
             grades = ['A', 'B', 'C', 'D', 'E']
-
-            return {
-                "grade": grades[grade_index],
-                "source": "model",
-                "prob": float(probs[grade_index]),
-                "probs": probs,
-                "logits": predictions[0].tolist(),
-            }
+            return grades[grade_index]
         except Exception as e:
             print(f"Model prediction error: {e}")
             word_score = self._calculate_word_scores(text)["total_score"]
-            grade = "C"
             if word_score > 1:
-                grade = "B"
+                return "B"
             if word_score < -1:
-                grade = "D"
+                return "D"
-            return {"grade": grade, "source": "word_score", "word_score_total": word_score}
+            return "C"
-
-    @staticmethod
-    def _softmax(x: np.ndarray) -> np.ndarray:
-        x = x - np.max(x, axis=-1, keepdims=True)
-        exp_x = np.exp(x)
-        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
 
     def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
         try:
@@ -365,17 +315,9 @@ class SentenceExtractor:
         scored_sentences = []
         total_sentence_score = 0
         for sentence in relevant_sentences:
-            info = self._predict_grade_with_model(sentence)
-            grade = info.get("grade", "C")
+            grade = self._predict_grade_with_model(sentence)
             score = score_map.get(grade, 3)
-
-            scored_sentences.append({
-                "sentence": sentence,
-                "grade": grade,
-                "source": info.get("source", "unknown"),
-                "prob": info.get("prob"),
-                "word_score_total": info.get("word_score_total"),
-            })
+            scored_sentences.append({"sentence": sentence, "grade": grade})
             total_sentence_score += score
 
         comprehensive_grade = "C"
@@ -404,19 +346,4 @@ class SentenceExtractor:
             "neutral_word_count": word_scores["neutral_count"],
             "scored_sentences": scored_sentences,
             "count": len(relevant_sentences),
-
-            "debug": {
-                "model_loaded": getattr(self, "model_loaded", False),
-                "model_path_abs": getattr(self, "model_path_abs", None),
-                "model_sha256": getattr(self, "model_sha256", None),
-                "aggregation_mode": self.aggregation_mode,
-                "min_sentence_char_len": self.min_sentence_char_len,
-                "merge_leading_punct": self.merge_leading_punct,
-                "word_score_plus_threshold": self.word_score_plus_threshold,
-                "word_score_minus_threshold": self.word_score_minus_threshold,
-                "relevant_sentences": relevant_sentences,
-                "word_score_total": word_scores["total_score"],
-            },
-        }
+        }
-
-
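For context, a minimal standalone sketch of the prediction path the file converges on after this commit. The function names and the features argument here are illustrative assumptions; the grade list, the ±1 word-score thresholds, and the model-first-then-fallback order come from the diff:

```python
import numpy as np
import onnxruntime as ort

GRADES = ["A", "B", "C", "D", "E"]

def load_session(model_path: str):
    """Load the ONNX model, returning None on failure so callers can fall back."""
    try:
        # The commit drops the explicit providers list; pinning
        # CPUExecutionProvider (as the removed code did) keeps the
        # execution backend deterministic across onnxruntime builds.
        return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    except Exception as e:
        print(f"Failed to load ONNX model: {e}")
        return None

def predict_grade(session, features: np.ndarray, word_score: int) -> str:
    """Grade one sentence: model output first, word-score thresholds as fallback."""
    if session is None:
        if word_score > 1:
            return "B"
        if word_score < -1:
            return "D"
        return "C"
    input_name = session.get_inputs()[0].name
    logits = session.run(None, {input_name: features})[0]  # shape (1, 5)
    return GRADES[int(np.argmax(logits))]
```

Returning a bare string is also what makes the _softmax helper removable: once the return value carries no per-class probabilities, the logits only ever need an argmax.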