Wind-xixi commited on
Commit
4dc1d85
·
verified ·
1 Parent(s): 52940f1

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +14 -87
predictor.py CHANGED
@@ -1,7 +1,5 @@
1
  import json
2
  import re
3
- import os
4
- import hashlib
5
  import onnxruntime as ort
6
  import numpy as np
7
  from typing import List, Dict, Set, Optional
@@ -24,17 +22,6 @@ class SentenceExtractor:
24
  word_score_plus_threshold: int = 1,
25
  word_score_minus_threshold: int = -1,
26
  ):
27
- # 统一以文件所在目录为根,避免工作目录不同导致找不到资源
28
- self.base_dir = os.path.dirname(os.path.abspath(__file__))
29
- self.tokenizer_dir = self.base_dir
30
-
31
- # 允许传相对路径:自动转绝对
32
- if not os.path.isabs(model_path):
33
- model_path = os.path.join(self.base_dir, model_path)
34
-
35
- if not os.path.isabs(eval_keywords_path):
36
- eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)
37
-
38
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
39
  self.all_keywords = self._extract_all_keywords()
40
 
@@ -50,43 +37,21 @@ class SentenceExtractor:
50
  self.word_score_plus_threshold = int(word_score_plus_threshold)
51
  self.word_score_minus_threshold = int(word_score_minus_threshold)
52
  try:
53
- # 强制使用 CPU provider,避免某些环境下选择到不可用的 GPU provider 导致加载失败
54
- self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
55
  self.input_name = self.ort_session.get_inputs()[0].name
56
  self.output_name = self.ort_session.get_outputs()[0].name
57
  print("ONNX 模型加载成功")
58
- self.model_loaded: bool = True
59
  except Exception as e:
60
  print(f"ONNX 模型加载失败: {e}")
61
  self.ort_session = None
62
- self.model_loaded = False
63
-
64
- # 记录模型文件信息,便于排查“用错模型”问题
65
- try:
66
- self.model_path_abs: Optional[str] = os.path.abspath(model_path)
67
- self.model_sha256: Optional[str] = None
68
- if os.path.exists(model_path):
69
- sha = hashlib.sha256()
70
- with open(model_path, 'rb') as f:
71
- for chunk in iter(lambda: f.read(8192), b''):
72
- sha.update(chunk)
73
- self.model_sha256 = sha.hexdigest()
74
- except Exception:
75
- self.model_path_abs = None
76
- self.model_sha256 = None
77
 
78
  def _preprocess_text(self, text: str) -> np.ndarray:
79
  try:
80
  from transformers import AutoTokenizer
81
- # 1) 优先从与脚本同目录加载本地 tokenizer(部署一起带上 tokenizer.json 等文件)
82
  try:
83
- tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
84
  except Exception:
85
- try:
86
- tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
87
- except Exception:
88
- # 2) 兜底:在线模型(需要外网)
89
- tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
90
  inputs = tokenizer(
91
  text,
92
  truncation=True,
@@ -103,16 +68,15 @@ class SentenceExtractor:
103
  features[0, i] = (ord(ch) % 256) / 255.0
104
  return features
105
 
106
- def _predict_grade_with_model(self, text: str) -> Dict[str, any]:
107
  try:
108
  if not self.ort_session:
109
  word_score = self._calculate_word_scores(text)["total_score"]
110
- grade = "C"
111
  if word_score > 1:
112
- grade = "B"
113
  if word_score < -1:
114
- grade = "D"
115
- return {"grade": grade, "source": "rule", "word_score_total": word_score}
116
 
117
  inputs = self._preprocess_text(text)
118
 
@@ -142,29 +106,15 @@ class SentenceExtractor:
142
  predictions = outputs[0]
143
  grade_index = int(np.argmax(predictions))
144
  grades = ['A', 'B', 'C', 'D', 'E']
145
- probs = self._softmax(predictions)[0].tolist()
146
- return {
147
- "grade": grades[grade_index],
148
- "source": "model",
149
- "prob": float(probs[grade_index]),
150
- "probs": probs,
151
- "logits": predictions[0].tolist(),
152
- }
153
  except Exception as e:
154
  print(f"模型预测出错: {e}")
155
  word_score = self._calculate_word_scores(text)["total_score"]
156
- grade = "C"
157
  if word_score > 1:
158
- grade = "B"
159
  if word_score < -1:
160
- grade = "D"
161
- return {"grade": grade, "source": "rule", "word_score_total": word_score}
162
-
163
- @staticmethod
164
- def _softmax(x: np.ndarray) -> np.ndarray:
165
- x = x - np.max(x, axis=-1, keepdims=True)
166
- exp_x = np.exp(x)
167
- return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
168
 
169
  def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
170
  try:
@@ -365,17 +315,9 @@ class SentenceExtractor:
365
  scored_sentences = []
366
  total_sentence_score = 0
367
  for sentence in relevant_sentences:
368
- info = self._predict_grade_with_model(sentence)
369
- grade = info.get("grade", "C")
370
  score = score_map.get(grade, 3)
371
- # 附带调试信息
372
- scored_sentences.append({
373
- "sentence": sentence,
374
- "grade": grade,
375
- "source": info.get("source", "unknown"),
376
- "prob": info.get("prob"),
377
- "word_score_total": info.get("word_score_total"),
378
- })
379
  total_sentence_score += score
380
 
381
  comprehensive_grade = "C"
@@ -404,19 +346,4 @@ class SentenceExtractor:
404
  "neutral_word_count": word_scores["neutral_count"],
405
  "scored_sentences": scored_sentences,
406
  "count": len(relevant_sentences),
407
- # 调试字段
408
- "debug": {
409
- "model_loaded": getattr(self, "model_loaded", False),
410
- "model_path_abs": getattr(self, "model_path_abs", None),
411
- "model_sha256": getattr(self, "model_sha256", None),
412
- "aggregation_mode": self.aggregation_mode,
413
- "min_sentence_char_len": self.min_sentence_char_len,
414
- "merge_leading_punct": self.merge_leading_punct,
415
- "word_score_plus_threshold": self.word_score_plus_threshold,
416
- "word_score_minus_threshold": self.word_score_minus_threshold,
417
- "relevant_sentences": relevant_sentences,
418
- "word_score_total": word_scores["total_score"],
419
- }
420
- }
421
-
422
-
 
1
  import json
2
  import re
 
 
3
  import onnxruntime as ort
4
  import numpy as np
5
  from typing import List, Dict, Set, Optional
 
22
  word_score_plus_threshold: int = 1,
23
  word_score_minus_threshold: int = -1,
24
  ):
 
 
 
 
 
 
 
 
 
 
 
25
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
26
  self.all_keywords = self._extract_all_keywords()
27
 
 
37
  self.word_score_plus_threshold = int(word_score_plus_threshold)
38
  self.word_score_minus_threshold = int(word_score_minus_threshold)
39
  try:
40
+ self.ort_session = ort.InferenceSession(model_path)
 
41
  self.input_name = self.ort_session.get_inputs()[0].name
42
  self.output_name = self.ort_session.get_outputs()[0].name
43
  print("ONNX 模型加载成功")
 
44
  except Exception as e:
45
  print(f"ONNX 模型加载失败: {e}")
46
  self.ort_session = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def _preprocess_text(self, text: str) -> np.ndarray:
49
  try:
50
  from transformers import AutoTokenizer
 
51
  try:
52
+ tokenizer = AutoTokenizer.from_pretrained(".", local_files_only=True)
53
  except Exception:
54
+ tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
 
 
 
 
55
  inputs = tokenizer(
56
  text,
57
  truncation=True,
 
68
  features[0, i] = (ord(ch) % 256) / 255.0
69
  return features
70
 
71
+ def _predict_grade_with_model(self, text: str) -> str:
72
  try:
73
  if not self.ort_session:
74
  word_score = self._calculate_word_scores(text)["total_score"]
 
75
  if word_score > 1:
76
+ return "B"
77
  if word_score < -1:
78
+ return "D"
79
+ return "C"
80
 
81
  inputs = self._preprocess_text(text)
82
 
 
106
  predictions = outputs[0]
107
  grade_index = int(np.argmax(predictions))
108
  grades = ['A', 'B', 'C', 'D', 'E']
109
+ return grades[grade_index]
 
 
 
 
 
 
 
110
  except Exception as e:
111
  print(f"模型预测出错: {e}")
112
  word_score = self._calculate_word_scores(text)["total_score"]
 
113
  if word_score > 1:
114
+ return "B"
115
  if word_score < -1:
116
+ return "D"
117
+ return "C"
 
 
 
 
 
 
118
 
119
  def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
120
  try:
 
315
  scored_sentences = []
316
  total_sentence_score = 0
317
  for sentence in relevant_sentences:
318
+ grade = self._predict_grade_with_model(sentence)
 
319
  score = score_map.get(grade, 3)
320
+ scored_sentences.append({"sentence": sentence, "grade": grade})
 
 
 
 
 
 
 
321
  total_sentence_score += score
322
 
323
  comprehensive_grade = "C"
 
346
  "neutral_word_count": word_scores["neutral_count"],
347
  "scored_sentences": scored_sentences,
348
  "count": len(relevant_sentences),
349
+ }