Wind-xixi committed on
Commit
bdb326c
·
verified ·
1 Parent(s): 4dc1d85

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +111 -21
predictor.py CHANGED
@@ -1,5 +1,7 @@
1
  import json
2
  import re
 
 
3
  import onnxruntime as ort
4
  import numpy as np
5
  from typing import List, Dict, Set, Optional
@@ -22,6 +24,17 @@ class SentenceExtractor:
22
  word_score_plus_threshold: int = 1,
23
  word_score_minus_threshold: int = -1,
24
  ):
 
 
 
 
 
 
 
 
 
 
 
25
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
26
  self.all_keywords = self._extract_all_keywords()
27
 
@@ -36,22 +49,51 @@ class SentenceExtractor:
36
  self.aggregation_mode = "max"
37
  self.word_score_plus_threshold = int(word_score_plus_threshold)
38
  self.word_score_minus_threshold = int(word_score_minus_threshold)
 
 
 
39
  try:
40
- self.ort_session = ort.InferenceSession(model_path)
 
41
  self.input_name = self.ort_session.get_inputs()[0].name
42
  self.output_name = self.ort_session.get_outputs()[0].name
 
 
 
 
43
  print("ONNX 模型加载成功")
 
44
  except Exception as e:
45
  print(f"ONNX 模型加载失败: {e}")
46
  self.ort_session = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def _preprocess_text(self, text: str) -> np.ndarray:
49
  try:
50
  from transformers import AutoTokenizer
 
51
  try:
52
- tokenizer = AutoTokenizer.from_pretrained(".", local_files_only=True)
53
  except Exception:
54
- tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
 
 
 
 
55
  inputs = tokenizer(
56
  text,
57
  truncation=True,
@@ -59,24 +101,25 @@ class SentenceExtractor:
59
  max_length=512,
60
  return_tensors='np'
61
  )
 
 
62
  return inputs
63
  except Exception as e:
64
- print(f"Tokenizer预处理失败: {e}")
65
- max_seq_length = 128
66
- features = np.zeros((1, max_seq_length), dtype=np.float32)
67
- for i, ch in enumerate(text[:max_seq_length]):
68
- features[0, i] = (ord(ch) % 256) / 255.0
69
- return features
70
-
71
- def _predict_grade_with_model(self, text: str) -> str:
72
  try:
73
  if not self.ort_session:
74
  word_score = self._calculate_word_scores(text)["total_score"]
 
75
  if word_score > 1:
76
- return "B"
77
  if word_score < -1:
78
- return "D"
79
- return "C"
80
 
81
  inputs = self._preprocess_text(text)
82
 
@@ -106,15 +149,36 @@ class SentenceExtractor:
106
  predictions = outputs[0]
107
  grade_index = int(np.argmax(predictions))
108
  grades = ['A', 'B', 'C', 'D', 'E']
109
- return grades[grade_index]
 
 
 
 
 
 
 
110
  except Exception as e:
111
  print(f"模型预测出错: {e}")
112
  word_score = self._calculate_word_scores(text)["total_score"]
 
113
  if word_score > 1:
114
- return "B"
115
  if word_score < -1:
116
- return "D"
117
- return "C"
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
120
  try:
@@ -315,9 +379,17 @@ class SentenceExtractor:
315
  scored_sentences = []
316
  total_sentence_score = 0
317
  for sentence in relevant_sentences:
318
- grade = self._predict_grade_with_model(sentence)
 
319
  score = score_map.get(grade, 3)
320
- scored_sentences.append({"sentence": sentence, "grade": grade})
 
 
 
 
 
 
 
321
  total_sentence_score += score
322
 
323
  comprehensive_grade = "C"
@@ -346,4 +418,22 @@ class SentenceExtractor:
346
  "neutral_word_count": word_scores["neutral_count"],
347
  "scored_sentences": scored_sentences,
348
  "count": len(relevant_sentences),
349
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import re
3
+ import os
4
+ import hashlib
5
  import onnxruntime as ort
6
  import numpy as np
7
  from typing import List, Dict, Set, Optional
 
24
  word_score_plus_threshold: int = 1,
25
  word_score_minus_threshold: int = -1,
26
  ):
27
+ # 统一以文件所在目录为根,避免工作目录不同导致找不到资源
28
+ self.base_dir = os.path.dirname(os.path.abspath(__file__))
29
+ self.tokenizer_dir = self.base_dir
30
+
31
+ # 允许传相对路径:自动转绝对
32
+ if not os.path.isabs(model_path):
33
+ model_path = os.path.join(self.base_dir, model_path)
34
+
35
+ if not os.path.isabs(eval_keywords_path):
36
+ eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)
37
+
38
  self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
39
  self.all_keywords = self._extract_all_keywords()
40
 
 
49
  self.aggregation_mode = "max"
50
  self.word_score_plus_threshold = int(word_score_plus_threshold)
51
  self.word_score_minus_threshold = int(word_score_minus_threshold)
52
+ self.providers: Optional[List[str]] = None
53
+ self.tokenizer_loaded: bool = False
54
+ self.last_tokenizer_error: Optional[str] = None
55
  try:
56
+ # 强制使用 CPU provider,避免某些环境下选择到不可用的 GPU provider 导致加载失败
57
+ self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
58
  self.input_name = self.ort_session.get_inputs()[0].name
59
  self.output_name = self.ort_session.get_outputs()[0].name
60
+ try:
61
+ self.providers = self.ort_session.get_providers()
62
+ except Exception:
63
+ self.providers = None
64
  print("ONNX 模型加载成功")
65
+ self.model_loaded: bool = True
66
  except Exception as e:
67
  print(f"ONNX 模型加载失败: {e}")
68
  self.ort_session = None
69
+ self.model_loaded = False
70
+
71
+ # 记录模型文件信息,便于排查“用错模型”问题
72
+ try:
73
+ self.model_path_abs: Optional[str] = os.path.abspath(model_path)
74
+ self.model_sha256: Optional[str] = None
75
+ if os.path.exists(model_path):
76
+ sha = hashlib.sha256()
77
+ with open(model_path, 'rb') as f:
78
+ for chunk in iter(lambda: f.read(8192), b''):
79
+ sha.update(chunk)
80
+ self.model_sha256 = sha.hexdigest()
81
+ except Exception:
82
+ self.model_path_abs = None
83
+ self.model_sha256 = None
84
 
85
  def _preprocess_text(self, text: str) -> np.ndarray:
86
  try:
87
  from transformers import AutoTokenizer
88
+ # 1) 优先从与脚本同目录加载本地 tokenizer(部署一起带上 tokenizer.json 等文件)
89
  try:
90
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
91
  except Exception:
92
+ try:
93
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
94
+ except Exception:
95
+ # 2) 兜底:在线模型(需要外网)
96
+ tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
97
  inputs = tokenizer(
98
  text,
99
  truncation=True,
 
101
  max_length=512,
102
  return_tensors='np'
103
  )
104
+ self.tokenizer_loaded = True
105
+ self.last_tokenizer_error = None
106
  return inputs
107
  except Exception as e:
108
+ self.tokenizer_loaded = False
109
+ self.last_tokenizer_error = str(e)
110
+ # 继续抛出异常,由上层捕获并回退,同时记录原因
111
+ raise
112
+
113
+ def _predict_grade_with_model(self, text: str) -> Dict[str, any]:
 
 
114
  try:
115
  if not self.ort_session:
116
  word_score = self._calculate_word_scores(text)["total_score"]
117
+ grade = "C"
118
  if word_score > 1:
119
+ grade = "B"
120
  if word_score < -1:
121
+ grade = "D"
122
+ return {"grade": grade, "source": "rule", "word_score_total": word_score}
123
 
124
  inputs = self._preprocess_text(text)
125
 
 
149
  predictions = outputs[0]
150
  grade_index = int(np.argmax(predictions))
151
  grades = ['A', 'B', 'C', 'D', 'E']
152
+ probs = self._softmax(predictions)[0].tolist()
153
+ return {
154
+ "grade": grades[grade_index],
155
+ "source": "model",
156
+ "prob": float(probs[grade_index]),
157
+ "probs": probs,
158
+ "logits": predictions[0].tolist(),
159
+ }
160
  except Exception as e:
161
  print(f"模型预测出错: {e}")
162
  word_score = self._calculate_word_scores(text)["total_score"]
163
+ grade = "C"
164
  if word_score > 1:
165
+ grade = "B"
166
  if word_score < -1:
167
+ grade = "D"
168
+ return {
169
+ "grade": grade,
170
+ "source": "rule",
171
+ "word_score_total": word_score,
172
+ "reason": str(e),
173
+ "tokenizer_loaded": self.tokenizer_loaded,
174
+ "last_tokenizer_error": self.last_tokenizer_error,
175
+ }
176
+
177
+ @staticmethod
178
+ def _softmax(x: np.ndarray) -> np.ndarray:
179
+ x = x - np.max(x, axis=-1, keepdims=True)
180
+ exp_x = np.exp(x)
181
+ return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
182
 
183
  def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
184
  try:
 
379
  scored_sentences = []
380
  total_sentence_score = 0
381
  for sentence in relevant_sentences:
382
+ info = self._predict_grade_with_model(sentence)
383
+ grade = info.get("grade", "C")
384
  score = score_map.get(grade, 3)
385
+ # 附带调试信息
386
+ scored_sentences.append({
387
+ "sentence": sentence,
388
+ "grade": grade,
389
+ "source": info.get("source", "unknown"),
390
+ "prob": info.get("prob"),
391
+ "word_score_total": info.get("word_score_total"),
392
+ })
393
  total_sentence_score += score
394
 
395
  comprehensive_grade = "C"
 
418
  "neutral_word_count": word_scores["neutral_count"],
419
  "scored_sentences": scored_sentences,
420
  "count": len(relevant_sentences),
421
+ # 调试字段
422
+ "debug": {
423
+ "model_loaded": getattr(self, "model_loaded", False),
424
+ "model_path_abs": getattr(self, "model_path_abs", None),
425
+ "model_sha256": getattr(self, "model_sha256", None),
426
+ "providers": self.providers,
427
+ "tokenizer_loaded": self.tokenizer_loaded,
428
+ "last_tokenizer_error": self.last_tokenizer_error,
429
+ "aggregation_mode": self.aggregation_mode,
430
+ "min_sentence_char_len": self.min_sentence_char_len,
431
+ "merge_leading_punct": self.merge_leading_punct,
432
+ "word_score_plus_threshold": self.word_score_plus_threshold,
433
+ "word_score_minus_threshold": self.word_score_minus_threshold,
434
+ "relevant_sentences": relevant_sentences,
435
+ "word_score_total": word_scores["total_score"],
436
+ }
437
+ }
438
+
439
+