Wind-xixi committed · verified
Commit 4b0fcc2 · 1 Parent(s): 9c3a940

Update predictor.py

Files changed (1)
  1. predictor.py +158 -155
predictor.py CHANGED
@@ -1,155 +1,158 @@
- # predictor.py
-
- import torch
- import re
- import os
- import json
- import onnxruntime as ort
- from collections import defaultdict, Counter
- from difflib import SequenceMatcher
- from transformers import BertTokenizerFast
-
-
- class DialogueEvaluator:
-     def __init__(self, model_dir, keywords_path):
-         print("Initializing DialogueEvaluator...")
-         # Load the model and tokenizer
-         self.tokenizer, self.model, self.id2label = self._load_model(model_dir)
-         print("✅ Model and Tokenizer loaded.")
-
-         # Load the keyword taxonomy from a JSON file
-         with open(keywords_path, 'r', encoding='utf-8') as f:
-             self.academic_keywords = json.load(f)
-         print("✅ Keywords loaded.")
-
-         # Build keyword regex patterns
-         self.keyword_patterns = self._build_keyword_patterns()
-         print("✅ Keyword patterns built.")
-
-         # Scene stopwords (sentence-initial filler particles)
-         self.scene_stopwords = r'^(嗯|啊|哦|呃|呐|哟)'
-         print("DialogueEvaluator initialized successfully.")
-
-     def _load_model(self, model_dir):
-         tokenizer = BertTokenizerFast.from_pretrained(model_dir)
-         sess_options = ort.SessionOptions()
-         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-         model_path = os.path.join(model_dir, "model_quantized.onnx")
-
-         try:
-             model = ort.InferenceSession(model_path, sess_options, providers=providers)
-             print(f"ℹ️ ONNX Runtime using: {model.get_providers()[0]}")
-         except Exception as e:
-             print(f"❌ ONNX Runtime initialization failed: {str(e)}")
-             raise
-
-         with open(os.path.join(model_dir, "label_map.json"), 'r', encoding='utf-8') as f:
-             label_map = json.load(f)
-         id2label = {int(k): v for k, v in label_map['id2label'].items()}
-         return tokenizer, model, id2label
-
-     def _build_keyword_patterns(self):
-         patterns = {}
-         for scene, sentiment_dict in self.academic_keywords.items():
-             for sentiment, keywords in sentiment_dict.items():
-                 for keyword in keywords:
-                     pattern = self._create_fuzzy_pattern(keyword)
-                     patterns[keyword] = {
-                         'pattern': pattern,
-                         'scene': scene,
-                         'sentiment': sentiment
-                     }
-         return patterns
-
-     def _create_fuzzy_pattern(self, keyword):
-         # Simple fuzzy matching: allow one arbitrary character between characters
-         if len(keyword) <= 2:
-             return re.compile(re.escape(keyword))
-         pattern_str = re.escape(keyword[0]) + ''.join([f"{re.escape(c)}.?" for c in keyword[1:]])
-         return re.compile(pattern_str)
-
-     def _fuzzy_match_keywords(self, sentence):
-         matched_info = []
-         for keyword, info in self.keyword_patterns.items():
-             if info['pattern'].search(sentence):
-                 # Naive handling of negation
-                 sentiment = info['sentiment']
-                 if re.search(fr'(不|没有|无|否|缺乏|不足|不够){keyword}', sentence):
-                     sentiment = 'negative'
-
-                 matched_info.append({
-                     'keyword': keyword,
-                     'scene': info['scene'],
-                     'sentiment': sentiment
-                 })
-         return matched_info
-
-     def _clean_sentence(self, sentence):
-         sentence = re.sub(r'[^\w\s\u4e00-\u9fff,。;:、]', '', sentence)
-         sentence = re.sub(r'\s+', ' ', sentence).strip()
-         sentence = re.sub(self.scene_stopwords, '', sentence)
-         return sentence
-
-     def _extract_key_sentences(self, text):
-         sentences = re.split(r'[。!?;\n]', text)
-         key_sentences = []
-         for sent in sentences:
-             if len(sent) < 5: continue  # skip sentences that are too short
-             clean_sent = self._clean_sentence(sent)
-             if not clean_sent: continue
-             matched_info = self._fuzzy_match_keywords(clean_sent)
-             if matched_info:
-                 key_sentences.append({
-                     'sentence': clean_sent,
-                     'matched_info': matched_info,
-                 })
-         return key_sentences
-
-     def _predict_sentence(self, sentence):
-         inputs = self.tokenizer(
-             sentence, truncation=True, padding='max_length', max_length=128, return_tensors="np"
-         )
-         ort_inputs = {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}
-         try:
-             outputs = self.model.run(None, ort_inputs)
-             logits = outputs[0]
-             probs = torch.softmax(torch.tensor(logits), dim=1)
-             pred_id = torch.argmax(probs).item()
-             return {
-                 'label': self.id2label[pred_id],
-                 'confidence': round(torch.max(probs).item(), 4)
-             }
-         except Exception as e:
-             print(f"❌ Inference failed for sentence: '{sentence}'. Error: {str(e)}")
-             return {'label': 'ERROR', 'confidence': 0.0}
-
-     def evaluate_full_text(self, text):
-         key_sentences_info = self._extract_key_sentences(text)
-         if not key_sentences_info:
-             return {'status': 'no_key_sentences', 'message': '未检测到包含评价关键词的有效句子。'}
-
-         processed_sentences = []
-         for sent_info in key_sentences_info:
-             prediction = self._predict_sentence(sent_info['sentence'])
-             sent_info.update(prediction)
-             processed_sentences.append(sent_info)
-
-         # --- Build aggregate statistics ---
-         overall_stats = defaultdict(lambda: defaultdict(int))
-         all_labels = [sent['label'] for sent in processed_sentences]
-         overall_stats['total_sentences'] = len(processed_sentences)
-         overall_stats['label_distribution'] = dict(Counter(all_labels))
-         overall_stats['avg_confidence'] = round(
-             sum(s['confidence'] for s in processed_sentences) / len(processed_sentences),
-             4) if processed_sentences else 0
-
-         for sent in processed_sentences:
-             for info in sent['matched_info']:
-                 overall_stats['scene_distribution'][info['scene']] += 1
-                 overall_stats['sentiment_distribution'][info['sentiment']] += 1
-
-         return {
-             'status': 'success',
-             'overall_stats': dict(overall_stats),
-             'key_sentences': processed_sentences
-         }
+ # predictor.py
+
+ import torch
+ import re
+ import os
+ import json
+ from pathlib import Path
+ import onnxruntime as ort
+ from collections import defaultdict, Counter
+ from difflib import SequenceMatcher
+ from transformers import BertTokenizerFast
+
+
+ class DialogueEvaluator:
+     def __init__(self, model_path, keywords_path):
+         print("Initializing DialogueEvaluator...")
+
+         # Load the model and tokenizer
+         self.tokenizer, self.model, self.id2label = self._load_model(model_path)
+         print("✅ Model and Tokenizer loaded.")
+
+         # Load the keyword taxonomy
+         with open(keywords_path, 'r', encoding='utf-8') as f:
+             self.academic_keywords = json.load(f)
+         print("✅ Keywords loaded.")
+
+         # Build keyword regex patterns
+         self.keyword_patterns = self._build_keyword_patterns()
+         print("✅ Keyword patterns built.")
+
+         # Scene stopwords (sentence-initial filler particles)
+         self.scene_stopwords = r'^(嗯|啊|哦|呃|呐|哟)'
+         print("DialogueEvaluator initialized successfully.")
+
+     def _load_model(self, model_path):
+         model_path = Path(model_path)
+         model_dir = model_path.parent
+
+         tokenizer = BertTokenizerFast.from_pretrained(model_dir)
+
+         sess_options = ort.SessionOptions()
+         providers = ['CPUExecutionProvider']  # Hugging Face Spaces usually has no GPU
+         model = ort.InferenceSession(str(model_path), sess_options, providers=providers)
+         print(f"ℹ️ ONNX Runtime using: {model.get_providers()[0]}")
+
+         label_map_path = model_dir / "label_map.json"
+         if not label_map_path.exists():
+             raise FileNotFoundError(f"Missing label_map.json at: {label_map_path}")
+
+         with open(label_map_path, 'r', encoding='utf-8') as f:
+             label_map = json.load(f)
+         id2label = {int(k): v for k, v in label_map['id2label'].items()}
+         return tokenizer, model, id2label
+
+     def _build_keyword_patterns(self):
+         patterns = {}
+         for scene, sentiment_dict in self.academic_keywords.items():
+             for sentiment, keywords in sentiment_dict.items():
+                 for keyword in keywords:
+                     pattern = self._create_fuzzy_pattern(keyword)
+                     patterns[keyword] = {
+                         'pattern': pattern,
+                         'scene': scene,
+                         'sentiment': sentiment
+                     }
+         return patterns
+
+     def _create_fuzzy_pattern(self, keyword):
+         if len(keyword) <= 2:
+             return re.compile(re.escape(keyword))
+         pattern_str = re.escape(keyword[0]) + ''.join([f"{re.escape(c)}.?" for c in keyword[1:]])
+         return re.compile(pattern_str)
+
+     def _fuzzy_match_keywords(self, sentence):
+         matched_info = []
+         for keyword, info in self.keyword_patterns.items():
+             if info['pattern'].search(sentence):
+                 sentiment = info['sentiment']
+                 if re.search(fr'(不|没有|无|否|缺乏|不足|不够){keyword}', sentence):
+                     sentiment = 'negative'
+
+                 matched_info.append({
+                     'keyword': keyword,
+                     'scene': info['scene'],
+                     'sentiment': sentiment
+                 })
+         return matched_info
+
+     def _clean_sentence(self, sentence):
+         sentence = re.sub(r'[^\w\s\u4e00-\u9fff,。;:、]', '', sentence)
+         sentence = re.sub(r'\s+', ' ', sentence).strip()
+         sentence = re.sub(self.scene_stopwords, '', sentence)
+         return sentence
+
+     def _extract_key_sentences(self, text):
+         sentences = re.split(r'[。!?;\n]', text)
+         key_sentences = []
+         for sent in sentences:
+             if len(sent) < 5:
+                 continue
+             clean_sent = self._clean_sentence(sent)
+             if not clean_sent:
+                 continue
+             matched_info = self._fuzzy_match_keywords(clean_sent)
+             if matched_info:
+                 key_sentences.append({
+                     'sentence': clean_sent,
+                     'matched_info': matched_info,
+                 })
+         return key_sentences
+
+     def _predict_sentence(self, sentence):
+         inputs = self.tokenizer(
+             sentence, truncation=True, padding='max_length', max_length=128, return_tensors="np"
+         )
+         ort_inputs = {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}
+         try:
+             outputs = self.model.run(None, ort_inputs)
+             logits = outputs[0]
+             probs = torch.softmax(torch.tensor(logits), dim=1)
+             pred_id = torch.argmax(probs).item()
+             return {
+                 'label': self.id2label[pred_id],
+                 'confidence': round(torch.max(probs).item(), 4)
+             }
+         except Exception as e:
+             print(f"❌ Inference failed for sentence: '{sentence}'. Error: {str(e)}")
+             return {'label': 'ERROR', 'confidence': 0.0}
+
+     def evaluate_full_text(self, text):
+         key_sentences_info = self._extract_key_sentences(text)
+         if not key_sentences_info:
+             return {'status': 'no_key_sentences', 'message': '未检测到包含评价关键词的有效句子。'}
+
+         processed_sentences = []
+         for sent_info in key_sentences_info:
+             prediction = self._predict_sentence(sent_info['sentence'])
+             sent_info.update(prediction)
+             processed_sentences.append(sent_info)
+
+         overall_stats = defaultdict(lambda: defaultdict(int))
+         all_labels = [sent['label'] for sent in processed_sentences]
+         overall_stats['total_sentences'] = len(processed_sentences)
+         overall_stats['label_distribution'] = dict(Counter(all_labels))
+         overall_stats['avg_confidence'] = round(
+             sum(s['confidence'] for s in processed_sentences) / len(processed_sentences),
+             4) if processed_sentences else 0
+
+         for sent in processed_sentences:
+             for info in sent['matched_info']:
+                 overall_stats['scene_distribution'][info['scene']] += 1
+                 overall_stats['sentiment_distribution'][info['sentiment']] += 1
+
+         return {
+             'status': 'success',
+             'overall_stats': dict(overall_stats),
+             'key_sentences': processed_sentences
+         }