# =========================
# model.py
# =========================
import re

import torch
import numpy as np
import pandas as pd
from transformers import (
    AutoModelForPreTraining,
    AutoTokenizer,
    pipeline,
)
import streamlit as st  # Added for caching

# =========================
# CONFIG (same as yours)
# =========================
# Device used for the discriminator and (as index 0 / -1) for the pipeline.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model identifiers.
ELECTRA_MODEL = "Seznam/small-e-czech"
CLF_MODEL = "Stremie/xlm-roberta-base-clickbait"

# Thresholds applied to the RTD score.
RTD_CLICKBAIT_TH = 0.20
RTD_BORDERLINE_TH = 0.15

# Thresholds applied to the classifier's clickbait probability.
CLF_CLICK_TH = 0.65
CLF_NOT_TH = 0.35

# Thresholds applied to the combined score.
COMB_CLICK_TH = 0.60
COMB_NOT_TH = 0.40


# =========================
# LOAD MODELS (with caching)
# =========================
# Use @st.cache_resource so the models are loaded only once.
@st.cache_resource
def load_models():
    """Load both models and the tokenizer.

    Returns:
        A ``(disc, tok, clf)`` tuple: the ``ELECTRA_MODEL`` pre-training model
        moved to ``DEVICE`` and put in eval mode, its tokenizer, and a
        ``text-classification`` pipeline built from ``CLF_MODEL``.
    """
    print("Načítám modely...")

    discriminator = AutoModelForPreTraining.from_pretrained(ELECTRA_MODEL)
    discriminator = discriminator.to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(ELECTRA_MODEL)

    pipeline_device = 0 if DEVICE == "cuda" else -1
    classifier = pipeline(
        "text-classification",
        model=CLF_MODEL,
        device=pipeline_device,
    )

    # ---- Robust label mapping for the classifier ----
    # If the checkpoint's config does not name both labels we rely on,
    # install a fixed 0 -> NOT / 1 -> CLICKBAIT mapping.
    config = classifier.model.config
    declared = getattr(config, "id2label", {}) or {}
    declared_upper = {str(name).upper() for name in declared.values()}
    if not {"CLICKBAIT", "NOT"}.issubset(declared_upper):
        config.id2label = {0: "NOT", 1: "CLICKBAIT"}
        config.label2id = {"NOT": 0, "CLICKBAIT": 1}

    print("Modely načteny.")
    return discriminator, tokenizer, classifier


# All your other functions (rtd_token_scores_batch, classify_supervised, etc.)
# are to be copied here WITHOUT CHANGES.
# ... (insert the rest of the functions from your script here) ...
# Matches a trailing integer in labels such as "LABEL_1".
_TRAILING_INT_RE = re.compile(r"(\d+)$")


def _sanitize_texts(texts):
    """Return *texts* as a list of stripped strings; missing values become ""."""
    return [str(t).strip() if pd.notna(t) else "" for t in texts]


def _model_device(model):
    """Return the device of *model*'s first parameter, else the global DEVICE."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device(DEVICE)


@torch.no_grad()
def rtd_token_scores_batch(texts, disc, tok, batch_size=32):
    """Compute per-token sigmoid scores of the discriminator for each text.

    Args:
        texts: Iterable of headlines. Missing values (None/NaN) are treated
            as empty strings and everything else is converted with ``str``.
        disc: Model whose output exposes ``.logits`` (one value per token).
        tok: Tokenizer callable returning a batch encoding.
        batch_size: Number of texts tokenized and scored per forward pass.

    Returns:
        A list with one 1-D NumPy array per input text. Each array holds only
        the positions marked as real tokens by the attention mask, so the
        result for a text does not depend on the other texts in its batch.
    """
    cleaned = _sanitize_texts(texts)
    device = _model_device(disc)

    all_scores = []
    for start in range(0, len(cleaned), batch_size):
        enc = tok(
            cleaned[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        out = disc(**enc)
        probs = torch.sigmoid(out.logits).detach().cpu().numpy()

        mask = enc["attention_mask"] if "attention_mask" in enc else None
        if mask is None:
            # No mask available: nothing to filter on, keep the full rows.
            all_scores.extend(probs)
            continue

        # Drop padding positions; boolean indexing works for either
        # padding side.
        keep = mask.detach().cpu().numpy().astype(bool)
        all_scores.extend(row[row_keep] for row, row_keep in zip(probs, keep))
    return all_scores


def clickbait_score_rtd_from_probs(probs, k_top: int = 5) -> float:
    """Reduce per-token scores to a single score in ``[0, 1]``.

    The first and last positions are dropped (presumably the tokenizer's
    leading/trailing special tokens -- confirm against the tokenizer), and
    the mean of the ``k_top`` largest remaining values is returned.

    Args:
        probs: 1-D sequence or array of per-token scores.
        k_top: How many of the highest scores to average. Values below 1
            are treated as 1; values above the number of available tokens
            are capped.

    Returns:
        The clamped mean of the top-k scores, or ``0.0`` when no tokens
        remain after trimming.
    """
    arr = np.asarray(probs)
    core = arr[1:-1] if arr.size >= 2 else arr
    if core.size == 0:
        return 0.0
    k = max(1, min(int(k_top), core.size))
    topk = np.partition(core, -k)[-k:]
    score = float(np.mean(topk))
    return max(0.0, min(1.0, score))


def rtd_label_from_score(p: float) -> str:
    """Map an RTD score to "CLICKBAIT", "BORDERLINE" or "NOT"."""
    if p >= RTD_CLICKBAIT_TH:
        return "CLICKBAIT"
    if p >= RTD_BORDERLINE_TH:
        return "BORDERLINE"
    return "NOT"


def _normalize_label_to_index(lbl, LABEL2ID):
    """Resolve a pipeline label to a class index.

    Args:
        lbl: Label as an int, or anything convertible with ``str``.
        LABEL2ID: Mapping from label name to class index.

    Returns:
        ``lbl`` itself if it is an int, the mapped index if its string form
        is a key of ``LABEL2ID``, the trailing integer of the string form
        (e.g. ``"LABEL_1"`` -> 1) otherwise, or ``None`` if none applies.
    """
    if isinstance(lbl, int):
        return lbl
    s = str(lbl)
    if s in LABEL2ID:
        return LABEL2ID[s]
    m = _TRAILING_INT_RE.search(s)
    if m:
        return int(m.group(1))
    return None


def classify_supervised(texts, clf):
    """Run the supervised classifier and summarise its output per text.

    Args:
        texts: Iterable of headlines. Missing values (None/NaN) are treated
            as empty strings and everything else is stripped.
        clf: Text-classification pipeline exposing ``model.config.id2label``
            and ``model.config.label2id``.

    Returns:
        One dict per input text with the keys ``clf_prob_clickbait``,
        ``clf_prob_not``, ``clf_label`` (binary), ``clf_label_3way`` and
        ``clf_margin``. An empty input yields an empty list without calling
        the pipeline.
    """
    sanitized = _sanitize_texts(texts)
    if not sanitized:
        return []

    id2label = clf.model.config.id2label
    label2id = clf.model.config.label2id

    outs = clf(sanitized, top_k=None, truncation=True, max_length=256)

    results = []
    for scores in outs:
        prob_click, prob_not = 0.0, 0.0
        for entry in scores:
            idx = _normalize_label_to_index(entry["label"], label2id)
            if idx is None:
                continue
            name = str(id2label.get(idx, str(entry["label"]))).upper()
            if name == "CLICKBAIT":
                prob_click = float(entry["score"])
            elif name == "NOT":
                prob_not = float(entry["score"])

        binary_label = "CLICKBAIT" if prob_click >= prob_not else "NOT"

        if prob_click >= CLF_CLICK_TH:
            tri_label = "CLICKBAIT"
        elif prob_click <= CLF_NOT_TH:
            tri_label = "NOT"
        else:
            tri_label = "BORDERLINE"

        results.append({
            "clf_prob_clickbait": prob_click,
            "clf_prob_not": prob_not,
            "clf_label": binary_label,
            "clf_label_3way": tri_label,
            "clf_margin": abs(prob_click - prob_not),
        })
    return results
# =========================
# MAIN PROCESSING FUNCTION
# =========================
# Weights of the two signals in the combined score.
_COMB_CLF_WEIGHT = 0.85
_COMB_RTD_WEIGHT = 0.15


def _combined_label(score: float) -> str:
    """Map a combined score to "CLICKBAIT", "NOT" or "BORDERLINE"."""
    if score >= COMB_CLICK_TH:
        return "CLICKBAIT"
    if score <= COMB_NOT_TH:
        return "NOT"
    return "BORDERLINE"


def process_headlines(headlines: list[str], k_top: int = 5) -> pd.DataFrame:
    """Score a list of headlines and return a DataFrame with the results.

    Args:
        headlines: Headlines to score. A single string is treated as a
            one-element list. Missing values (None/NaN) are scored as empty
            strings but kept as-is in the output column.
        k_top: Forwarded to ``clickbait_score_rtd_from_probs``.

    Returns:
        An empty ``DataFrame`` when there is no non-blank headline (the
        models are not loaded in that case). Otherwise one row per input
        headline with the columns "Titulek", "Výsledek",
        "Kombinované skóre", "Pravděpodobnost clickbaitu" and "RTD skóre".
    """
    if headlines is None:
        raw = []
    elif isinstance(headlines, str):
        raw = [headlines]
    else:
        raw = list(headlines)

    # Text actually sent to the models: stripped, with missing values
    # replaced by "".
    cleaned = [str(h).strip() if pd.notna(h) else "" for h in raw]
    if not any(cleaned):
        return pd.DataFrame()

    disc, tok, clf = load_models()

    # RTD
    rtd_probs_all = rtd_token_scores_batch(cleaned, disc, tok, batch_size=32)
    rtd_scores = [
        clickbait_score_rtd_from_probs(p, k_top=k_top) for p in rtd_probs_all
    ]

    # Supervised
    df_sup = pd.DataFrame(classify_supervised(cleaned, clf))

    # Assemble the results; both frames share the default 0..n-1 index.
    df_out = pd.DataFrame({"Titulek": raw})
    df_out["rtd_score"] = rtd_scores
    df_out = pd.concat([df_out, df_sup], axis=1)

    df_out["combined_score"] = (
        _COMB_CLF_WEIGHT * df_out["clf_prob_clickbait"]
        + _COMB_RTD_WEIGHT * df_out["rtd_score"]
    )
    df_out["final_label"] = [
        _combined_label(s) for s in df_out["combined_score"]
    ]

    # Select and rename the columns for readability.
    final_cols = {
        "Titulek": "Titulek",
        "final_label": "Výsledek",
        "combined_score": "Kombinované skóre",
        "clf_prob_clickbait": "Pravděpodobnost clickbaitu",
        "rtd_score": "RTD skóre",
    }
    return df_out[list(final_cols)].rename(columns=final_cols)