import json
import re
import os
import hashlib
import string
import onnxruntime as ort
import numpy as np
from typing import Any, List, Dict, Set, Optional

score_map = {'A': 5, 'B': 4, 'C': 3, 'D': 2, 'E': 1}


class SentenceExtractor:
    def __init__(
        self,
        eval_keywords_path: str,
        model_path: str = "distilled_model.onnx",
        *,
        # Configurable switches for sentence splitting and aggregation.
        merge_leading_punct: bool = True,
        min_sentence_char_len: int = 6,
        aggregation_mode: str = "max",  # one of "max" | "mean"
        # Thresholds for the "+"/"-" suffix (>0 / <0 is the original logic;
        # consider raising them moderately to 2 / -2).
        word_score_plus_threshold: int = 1,
        word_score_minus_threshold: int = -1,
    ):
        # Resolve resources relative to this file's directory so a different
        # working directory cannot break resource lookup.
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.tokenizer_dir = self.base_dir
        # Relative paths are allowed: convert them to absolute paths.
        if not os.path.isabs(model_path):
            model_path = os.path.join(self.base_dir, model_path)
        if not os.path.isabs(eval_keywords_path):
            eval_keywords_path = os.path.join(self.base_dir, eval_keywords_path)
        self.eval_keywords = self._load_eval_keywords(eval_keywords_path)
        self.all_keywords = self._extract_all_keywords()
        self.ort_session = None
        self.input_name = None
        self.output_name = None
        # Configuration.
        self.merge_leading_punct = merge_leading_punct
        self.min_sentence_char_len = max(0, int(min_sentence_char_len))
        self.aggregation_mode = aggregation_mode.lower().strip()
        if self.aggregation_mode not in {"max", "mean"}:
            self.aggregation_mode = "max"
        self.word_score_plus_threshold = int(word_score_plus_threshold)
        self.word_score_minus_threshold = int(word_score_minus_threshold)
        self.providers: Optional[List[str]] = None
        self.tokenizer_loaded: bool = False
        self.last_tokenizer_error: Optional[str] = None
        try:
            # Force the CPU provider: in some environments an unavailable GPU
            # provider would otherwise be selected and loading would fail.
            self.ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
            self.input_name = self.ort_session.get_inputs()[0].name
            self.output_name = self.ort_session.get_outputs()[0].name
            try:
                self.providers = self.ort_session.get_providers()
            except Exception:
                self.providers = None
            print("ONNX model loaded successfully")
            self.model_loaded: bool = True
        except Exception as e:
            print(f"Failed to load ONNX model: {e}")
            self.ort_session = None
            self.model_loaded = False
        # Record model file metadata to help diagnose "wrong model" issues.
        try:
            self.model_path_abs: Optional[str] = os.path.abspath(model_path)
            self.model_sha256: Optional[str] = None
            if os.path.exists(model_path):
                sha = hashlib.sha256()
                with open(model_path, 'rb') as f:
                    for chunk in iter(lambda: f.read(8192), b''):
                        sha.update(chunk)
                self.model_sha256 = sha.hexdigest()
        except Exception:
            self.model_path_abs = None
            self.model_sha256 = None

    def _preprocess_text(self, text: str) -> Dict[str, np.ndarray]:
        try:
            from transformers import AutoTokenizer
            # 1) Prefer a local tokenizer next to this script (ship
            #    tokenizer.json etc. together with the deployment).
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir, local_files_only=True)
            except Exception:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_dir)
                except Exception:
                    # 2) Last resort: the online model (requires network access).
                    tokenizer = AutoTokenizer.from_pretrained("uer/chinese_roberta_L-4_H-256")
            inputs = tokenizer(
                text,
                truncation=True,
                padding=True,
                max_length=512,
                return_tensors='np'
            )
            self.tokenizer_loaded = True
            self.last_tokenizer_error = None
            # BatchEncoding is a UserDict, not a dict subclass; convert so the
            # isinstance(..., dict) check downstream behaves as intended.
            return dict(inputs)
        except Exception as e:
            self.tokenizer_loaded = False
            self.last_tokenizer_error = str(e)
            # Re-raise so the caller can fall back; the reason is recorded above.
            raise

    def _predict_grade_with_model(self, text: str) -> Dict[str, Any]:
        try:
            if not self.ort_session:
                word_score = self._calculate_word_scores(text)["total_score"]
                grade = "C"
                if word_score > 1:
                    grade = "B"
                if word_score < -1:
                    grade = "D"
                return {"grade": grade, "source": "rule", "word_score_total": word_score}
            inputs = self._preprocess_text(text)
            model_input_names = [i.name for i in self.ort_session.get_inputs()]
            input_data = {}
            if isinstance(inputs, dict) and 'input_ids' in inputs:
                token_type = inputs.get('token_type_ids')
                attn = inputs.get('attention_mask')
                ids = inputs['input_ids']
                # Map tokenizer outputs onto the model's input names, which can
                # differ between exports (e.g. "attention_mask" vs "mask").
                for name in model_input_names:
                    lowered = name.lower()
                    if 'mask' in lowered:
                        input_data[name] = attn if attn is not None else np.ones_like(ids)
                    elif 'token_type' in lowered or 'segment' in lowered:
                        if token_type is None:
                            token_type = np.zeros_like(ids)
                        input_data[name] = token_type
                    elif 'input_ids' in lowered or 'input' in lowered or 'ids' in lowered:
                        input_data[name] = ids
                    else:
                        input_data[name] = np.zeros_like(ids)
            else:
                target_input = self.input_name or (model_input_names[0] if model_input_names else 'input')
                input_data = {target_input: inputs}
            outputs = self.ort_session.run([self.output_name], input_data)
            predictions = outputs[0]
            grade_index = int(np.argmax(predictions))
            grades = ['A', 'B', 'C', 'D', 'E']
            probs = self._softmax(predictions)[0].tolist()
            return {
                "grade": grades[grade_index],
                "source": "model",
                "prob": float(probs[grade_index]),
                "probs": probs,
                "logits": predictions[0].tolist(),
            }
        except Exception as e:
            print(f"Model prediction failed: {e}")
            word_score = self._calculate_word_scores(text)["total_score"]
            grade = "C"
            if word_score > 1:
                grade = "B"
            if word_score < -1:
                grade = "D"
            return {
                "grade": grade,
                "source": "rule",
                "word_score_total": word_score,
                "reason": str(e),
                "tokenizer_loaded": self.tokenizer_loaded,
                "last_tokenizer_error": self.last_tokenizer_error,
            }

    @staticmethod
    def _softmax(x: np.ndarray) -> np.ndarray:
        x = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def _load_eval_keywords(self, file_path: str) -> Dict[str, Dict[str, List[str]]]:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"Failed to load evaluation keyword lexicon: {e}")
            return {}

    def _extract_all_keywords(self) -> Set[str]:
        keywords_set = set()
        for category, types in self.eval_keywords.items():
            for _, keywords in types.items():
                keywords_set.update(keywords)
        return keywords_set

    def _split_into_sentences(self, text: str) -> List[str]:
        if not text:
            return []
        # First split on strong terminators.
        normalized = re.sub(r'([。!?.!?])', r'\1\n', text)
        normalized = re.sub(r'[;;]\s*', ';\n', normalized)
        candidates = [s.strip() for s in re.split(r'[\r\n]+', normalized) if s.strip()]
        # Further split long unterminated sentences on commas.
        rough_sentences: List[str] = []
        for s in candidates:
            if len(s) > 80 and not re.search(r'[。!?.!?;;]', s):
                parts = re.split(r'[,,]', s)
                rough_sentences.extend([p.strip() for p in parts if p.strip()])
            else:
                rough_sentences.append(s)
        # Merge fragments that start with punctuation and filter out
        # very short sentences.
        sentences: List[str] = []
        leading_punct_pattern = r'^[,,。;;::、\s]+'
        for s in rough_sentences:
            if self.merge_leading_punct and re.match(leading_punct_pattern, s):
                # Strip the leading punctuation, then merge into the previous sentence.
                cleaned = re.sub(leading_punct_pattern, '', s)
                if sentences:
                    sentences[-1] = f"{sentences[-1]}{cleaned}"
                elif cleaned:
                    sentences.append(cleaned)
                continue
            # Filter very short sentences (length measured without punctuation).
            plain = re.sub(r'[,,。;;::、!!??\s]', '', s)
            if self.min_sentence_char_len > 0 and len(plain) < self.min_sentence_char_len:
                # Do not drop outright: merge into the previous sentence if any.
                if sentences:
                    sentences[-1] = f"{sentences[-1]}{s}"
                else:
                    sentences.append(s)
                continue
            sentences.append(s)
        return [s.strip() for s in sentences if s and s.strip()]

    def _fuzzy_match_keyword(self, sentence: str, keyword: str) -> bool:
        """Stricter Chinese keyword matching.

        - Keywords shorter than 2 characters (e.g. "好") are matched only as
          exact tokens after word segmentation, so they do not hit every sentence.
        - All other keywords use substring matching after punctuation is stripped.
        """
        if not keyword:
            return False
        # Normalize whitespace.
        sentence = sentence.strip()
        keyword = keyword.strip()
        # Very short keywords go through exact token matching to avoid
        # spurious hits.
        if len(keyword) < 2:
            try:
                import jieba  # listed in requirements
                tokens = set(jieba.lcut(sentence))
                return keyword in tokens
            except Exception:
                # Fallback: no fuzzy matching for very short keywords.
                return False
        # Regular keywords: substring matching after stripping punctuation.
        # string.punctuation covers ASCII only, so common Chinese punctuation
        # is stripped explicitly as well.
        punct = string.punctuation + ',。;:、!?“”‘’()《》【】…—·'
        trans = str.maketrans('', '', punct)
        sentence_clean = sentence.translate(trans)
        keyword_clean = keyword.translate(trans)
        return keyword_clean in sentence_clean

    def _is_negated_positive(self, text: str, keyword: str) -> bool:
        """Detect whether a positive keyword is modified by a negation, e.g.:

        - 没有/无/不/非/未/并不/毫无 + keyword
        - For positive words starting with "有" (e.g. "有创新性"), also match
          没有/无/不/未/并不/毫无 + the part after "有" (e.g. "创新性").
        - 缺乏/不足/欠缺/缺少/不具备 + keyword, or the keyword with the
          leading "有" removed.
        """
        if not keyword:
            return False
        sentence = text.strip()
        neg_prefixes = [
            "没有", "没", "无", "不", "非", "未", "并不", "并没有", "并无", "毫无"
        ]
        lack_prefixes = [
            "缺乏", "不足", "欠缺", "缺少", "不具备", "不够"
        ]

        # Build regex fragments with all alternatives safely escaped.
        def any_prefix(prefixes: List[str]) -> str:
            return "(?:" + "|".join(re.escape(p) for p in prefixes) + ")"

        patterns: List[str] = []
        # Direct: negation prefix + keyword.
        patterns.append(rf"{any_prefix(neg_prefixes)}\s*{re.escape(keyword)}")
        # Direct: "lack of" prefix + keyword.
        patterns.append(rf"{any_prefix(lack_prefixes)}\s*{re.escape(keyword)}")
        # If the positive word starts with "有", also match its tail
        # (e.g. "有创新性" → "创新性").
        if keyword.startswith("有") and len(keyword) > 1:
            tail = keyword[1:]
            patterns.append(rf"{any_prefix(neg_prefixes)}\s*{re.escape(tail)}")
            patterns.append(rf"{any_prefix(lack_prefixes)}\s*{re.escape(tail)}")
        return any(re.search(pat, sentence) for pat in patterns)

    def _extract_relevant_sentences(self, text: str) -> List[str]:
        sentences = self._split_into_sentences(text)
        relevant_sentences: List[str] = []
        for sentence in sentences:
            if sentence in relevant_sentences:
                continue
            # A sentence is relevant as soon as any keyword in any
            # category/sentiment bucket matches.
            matched = any(
                self._fuzzy_match_keyword(sentence, keyword)
                for category in ("student_performance", "content_quality", "cross_scene")
                for sentiment in ("positive", "negative", "nature", "suggestion")
                for keyword in self.eval_keywords.get(category, {}).get(sentiment, [])
            )
            if matched:
                relevant_sentences.append(sentence)
        return relevant_sentences

    def _calculate_word_scores(self, text: str) -> Dict[str, int]:
        positive_count = 0
        negative_count = 0
        neutral_count = 0
        total_score = 0
        for category in ["student_performance", "content_quality", "cross_scene"]:
            if category not in self.eval_keywords:
                continue
            for keyword in self.eval_keywords[category].get("positive", []):
                if self._fuzzy_match_keyword(text, keyword):
                    # A negated positive word (e.g. "没有创新性" contains
                    # "有创新性") counts as negative.
                    if self._is_negated_positive(text, keyword):
                        negative_count += 1
                        total_score -= 1
                    else:
                        positive_count += 1
                        total_score += 1
            for keyword in self.eval_keywords[category].get("negative", []):
                if self._fuzzy_match_keyword(text, keyword):
                    negative_count += 1
                    total_score -= 1
            for keyword in self.eval_keywords[category].get("nature", []):
                if self._fuzzy_match_keyword(text, keyword):
                    neutral_count += 1
        return {
            "positive_count": positive_count,
            "negative_count": negative_count,
            "neutral_count": neutral_count,
            "total_score": total_score,
        }

    def extract(self, text: str) -> Dict[str, Any]:
        if not text:
            return {
                "comprehensive_grade": "C",
                "positive_word_count": 0,
                "negative_word_count": 0,
                "neutral_word_count": 0,
                "scored_sentences": [],
                "count": 0,
            }
        relevant_sentences = self._extract_relevant_sentences(text)
        scored_sentences = []
        total_sentence_score = 0
        for sentence in relevant_sentences:
            info = self._predict_grade_with_model(sentence)
            grade = info.get("grade", "C")
            score = score_map.get(grade, 3)
            # Attach debug information alongside the grade.
            scored_sentences.append({
                "sentence": sentence,
                "grade": grade,
                "source": info.get("source", "unknown"),
                "prob": info.get("prob"),
                "word_score_total": info.get("word_score_total"),
            })
            total_sentence_score += score
        comprehensive_grade = "C"
        if relevant_sentences:
            reverse_map = {5: 'A', 4: 'B', 3: 'C', 2: 'D', 1: 'E'}
            if self.aggregation_mode == "max":
                # Take the highest grade (more robust: short fragments cannot
                # drag the mean down).
                max_score = max(score_map.get(item["grade"], 3) for item in scored_sentences)
                comprehensive_grade = reverse_map.get(max_score, "C")
            else:
                avg_score = total_sentence_score / len(relevant_sentences)
                rounded_score = int(round(avg_score))
                comprehensive_grade = reverse_map.get(rounded_score, "C")
        word_scores = self._calculate_word_scores(text)
        final_grade = comprehensive_grade
        if word_scores["total_score"] > self.word_score_plus_threshold:
            final_grade = comprehensive_grade + "+"
        elif word_scores["total_score"] < self.word_score_minus_threshold:
            final_grade = comprehensive_grade + "-"
        return {
            "comprehensive_grade": final_grade,
            "positive_word_count": word_scores["positive_count"],
            "negative_word_count": word_scores["negative_count"],
            "neutral_word_count": word_scores["neutral_count"],
            "scored_sentences": scored_sentences,
            "count": len(relevant_sentences),
            # Debug fields.
            "debug": {
                "model_loaded": getattr(self, "model_loaded", False),
                "model_path_abs": getattr(self, "model_path_abs", None),
                "model_sha256": getattr(self, "model_sha256", None),
                "providers": self.providers,
                "tokenizer_loaded": self.tokenizer_loaded,
                "last_tokenizer_error": self.last_tokenizer_error,
                "aggregation_mode": self.aggregation_mode,
                "min_sentence_char_len": self.min_sentence_char_len,
                "merge_leading_punct": self.merge_leading_punct,
                "word_score_plus_threshold": self.word_score_plus_threshold,
                "word_score_minus_threshold": self.word_score_minus_threshold,
                "relevant_sentences": relevant_sentences,
                "word_score_total": word_scores["total_score"],
            },
        }
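

# --- Usage sketch (illustrative, not part of the module API) ---
# A minimal demonstration of the extractor. The file name "eval_keywords.json"
# and the sample sentence are assumptions for illustration; the model file
# defaults to "distilled_model.onnx" next to this script, and the class falls
# back to rule-based scoring when the model or tokenizer cannot be loaded.
# Based on the lookups above, the keyword lexicon is expected to be shaped like:
#   {
#     "student_performance": {
#       "positive": ["积极", "有创新性"],
#       "negative": ["消极"],
#       "nature": ["一般"],
#       "suggestion": ["建议"]
#     },
#     "content_quality": { ... },
#     "cross_scene": { ... }
#   }
if __name__ == "__main__":
    extractor = SentenceExtractor(
        eval_keywords_path="eval_keywords.json",  # hypothetical path
        aggregation_mode="mean",
        word_score_plus_threshold=2,   # the stricter "+" threshold suggested above
        word_score_minus_threshold=-2,
    )
    result = extractor.extract("课堂气氛活跃,学生表现积极,但内容缺乏创新性。")
    print(result["comprehensive_grade"], result["count"])
    for item in result["scored_sentences"]:
        print(item["grade"], item["source"], item["sentence"])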