Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- Dockerfile +21 -20
- app.py +29 -0
- model.py +161 -0
- requirements.txt +6 -3
Dockerfile
CHANGED
|
@@ -1,20 +1,21 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
| 1 |
+
# Použijeme oficiální Python image
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
# Nastavíme pracovní adresář v kontejneru
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Zkopírujeme soubor se závislostmi
|
| 8 |
+
COPY requirements.txt ./
|
| 9 |
+
|
| 10 |
+
# Nainstalujeme závislosti
|
| 11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 12 |
+
|
| 13 |
+
# Zkopírujeme zbytek kódu aplikace
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
# Vystavíme port, na kterém poběží Streamlit
|
| 17 |
+
EXPOSE 8080
|
| 18 |
+
|
| 19 |
+
# Příkaz, který se spustí při startu kontejneru
|
| 20 |
+
# Spustí Streamlit aplikaci na portu 8080
|
| 21 |
+
CMD ["streamlit", "run", "app.py", "--server.port=8080", "--server.address=0.0.0.0"]
|
app.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================
|
| 2 |
+
# app.py
|
| 3 |
+
# =========================
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from model import process_headlines
|
| 7 |
+
|
| 8 |
+
st.set_page_config(layout="wide")
|
| 9 |
+
st.title("🧪 Detektor clickbaitu")
|
| 10 |
+
st.markdown("Vložte jeden nebo více titulků (každý na nový řádek) a klikněte na 'Analyzovat'.")
|
| 11 |
+
|
| 12 |
+
# Vstupní pole pro text
|
| 13 |
+
input_text = st.text_area("Zadejte titulky:", height=200, placeholder="Např.:\nŠokující odhalení!\nToto neuvěříte!\nBěžná zpráva o počasí.")
|
| 14 |
+
|
| 15 |
+
# Tlačítko pro spuštění analýzy
|
| 16 |
+
if st.button("Analyzovat"):
|
| 17 |
+
if input_text.strip():
|
| 18 |
+
# Rozdělení textu na řádky a odstranění prázdných
|
| 19 |
+
headlines = [line.strip() for line in input_text.split('\n') if line.strip()]
|
| 20 |
+
|
| 21 |
+
with st.spinner("Probíhá analýza... Modely se poprvé stahují, může to trvat i několik minut."):
|
| 22 |
+
try:
|
| 23 |
+
results_df = process_headlines(headlines)
|
| 24 |
+
st.success("Analýza dokončena!")
|
| 25 |
+
st.dataframe(results_df)
|
| 26 |
+
except Exception as e:
|
| 27 |
+
st.error(f"Při analýze došlo k chybě: {e}")
|
| 28 |
+
else:
|
| 29 |
+
st.warning("Zadejte prosím alespoň jeden titulek.")
|
model.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================
|
| 2 |
+
# model.py
|
| 3 |
+
# =========================
|
| 4 |
+
import re
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from transformers import (
|
| 9 |
+
AutoModelForPreTraining,
|
| 10 |
+
AutoTokenizer,
|
| 11 |
+
pipeline,
|
| 12 |
+
)
|
| 13 |
+
import streamlit as st # Přidáme pro cachování
|
| 14 |
+
|
| 15 |
+
# =========================
|
| 16 |
+
# CONFIG (stejné jako u vás)
|
| 17 |
+
# =========================
|
| 18 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
ELECTRA_MODEL = "Seznam/small-e-czech"
|
| 20 |
+
CLF_MODEL = "Stremie/xlm-roberta-base-clickbait"
|
| 21 |
+
|
| 22 |
+
RTD_CLICKBAIT_TH = 0.20
|
| 23 |
+
RTD_BORDERLINE_TH = 0.15
|
| 24 |
+
CLF_CLICK_TH = 0.65
|
| 25 |
+
CLF_NOT_TH = 0.35
|
| 26 |
+
COMB_CLICK_TH = 0.60
|
| 27 |
+
COMB_NOT_TH = 0.40
|
| 28 |
+
|
| 29 |
+
# =========================
|
| 30 |
+
# LOAD MODELS (s cachováním)
|
| 31 |
+
# =========================
|
| 32 |
+
# Použijeme @st.cache_resource, aby se modely načetly jen jednou
|
| 33 |
+
@st.cache_resource
|
| 34 |
+
def load_models():
|
| 35 |
+
"""Načte a vrátí oba modely a tokenizer."""
|
| 36 |
+
print("Načítám modely...")
|
| 37 |
+
disc = AutoModelForPreTraining.from_pretrained(ELECTRA_MODEL).to(DEVICE).eval()
|
| 38 |
+
tok = AutoTokenizer.from_pretrained(ELECTRA_MODEL)
|
| 39 |
+
|
| 40 |
+
clf = pipeline(
|
| 41 |
+
"text-classification",
|
| 42 |
+
model=CLF_MODEL,
|
| 43 |
+
device=0 if DEVICE == "cuda" else -1
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# ---- Robust label mapping pro klasifikátor ----
|
| 47 |
+
id2label = getattr(clf.model.config, "id2label", {}) or {}
|
| 48 |
+
label_values_upper = {str(v).upper() for v in id2label.values()}
|
| 49 |
+
if not ({"CLICKBAIT", "NOT"} <= label_values_upper):
|
| 50 |
+
clf.model.config.id2label = {0: "NOT", 1: "CLICKBAIT"}
|
| 51 |
+
clf.model.config.label2id = {"NOT": 0, "CLICKBAIT": 1}
|
| 52 |
+
|
| 53 |
+
print("Modely načteny.")
|
| 54 |
+
return disc, tok, clf
|
| 55 |
+
|
| 56 |
+
# Všechny vaše ostatní funkce (rtd_token_scores_batch, classify_supervised, atd.)
|
| 57 |
+
# zde zkopírujte BEZE ZMĚN.
|
| 58 |
+
# ... (vložte sem zbytek funkcí z vašeho skriptu) ...
|
| 59 |
+
@torch.no_grad()
|
| 60 |
+
def rtd_token_scores_batch(texts, disc, tok, batch_size=32):
|
| 61 |
+
all_scores = []
|
| 62 |
+
for i in range(0, len(texts), batch_size):
|
| 63 |
+
enc = tok(texts[i:i+batch_size], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
|
| 64 |
+
out = disc(**enc)
|
| 65 |
+
probs = torch.sigmoid(out.logits).detach().cpu().numpy()
|
| 66 |
+
all_scores.extend(probs)
|
| 67 |
+
return all_scores
|
| 68 |
+
|
| 69 |
+
def clickbait_score_rtd_from_probs(probs, k_top: int = 5) -> float:
|
| 70 |
+
core = probs[1:-1] if len(probs) >= 2 else probs
|
| 71 |
+
if core.size == 0: return 0.0
|
| 72 |
+
k = min(k_top, core.size)
|
| 73 |
+
topk = np.partition(core, -k)[-k:]
|
| 74 |
+
score = float(np.mean(topk))
|
| 75 |
+
return max(0.0, min(1.0, score))
|
| 76 |
+
|
| 77 |
+
def rtd_label_from_score(p: float) -> str:
|
| 78 |
+
if p >= RTD_CLICKBAIT_TH: return "CLICKBAIT"
|
| 79 |
+
if p >= RTD_BORDERLINE_TH: return "BORDERLINE"
|
| 80 |
+
return "NOT"
|
| 81 |
+
|
| 82 |
+
def _normalize_label_to_index(lbl, LABEL2ID):
|
| 83 |
+
if isinstance(lbl, int): return lbl
|
| 84 |
+
s = str(lbl)
|
| 85 |
+
if s in LABEL2ID: return LABEL2ID[s]
|
| 86 |
+
m = re.search(r"(\d+)$", s)
|
| 87 |
+
if m: return int(m.group(1))
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
def classify_supervised(texts, clf):
|
| 91 |
+
ID2LABEL = clf.model.config.id2label
|
| 92 |
+
LABEL2ID = clf.model.config.label2id
|
| 93 |
+
sanitized = [str(t).strip() if pd.notna(t) else "" for t in texts]
|
| 94 |
+
outs = clf(sanitized, top_k=None, truncation=True, max_length=256)
|
| 95 |
+
results = []
|
| 96 |
+
for scores in outs:
|
| 97 |
+
prob_click, prob_not = 0.0, 0.0
|
| 98 |
+
for s in scores:
|
| 99 |
+
idx = _normalize_label_to_index(s["label"], LABEL2ID)
|
| 100 |
+
if idx is None: continue
|
| 101 |
+
name = ID2LABEL.get(idx, str(s["label"])).upper()
|
| 102 |
+
if name == "CLICKBAIT": prob_click = float(s["score"])
|
| 103 |
+
elif name == "NOT": prob_not = float(s["score"])
|
| 104 |
+
|
| 105 |
+
binary_label = "CLICKBAIT" if prob_click >= prob_not else "NOT"
|
| 106 |
+
if prob_click >= CLF_CLICK_TH: tri_label = "CLICKBAIT"
|
| 107 |
+
elif prob_click <= CLF_NOT_TH: tri_label = "NOT"
|
| 108 |
+
else: tri_label = "BORDERLINE"
|
| 109 |
+
clf_margin = abs(prob_click - prob_not)
|
| 110 |
+
results.append({
|
| 111 |
+
"clf_prob_clickbait": prob_click, "clf_prob_not": prob_not,
|
| 112 |
+
"clf_label": binary_label, "clf_label_3way": tri_label,
|
| 113 |
+
"clf_margin": clf_margin,
|
| 114 |
+
})
|
| 115 |
+
return results
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# =========================
|
| 119 |
+
# HLAVNÍ FUNKCE PRO ZPRACOVÁNÍ
|
| 120 |
+
# =========================
|
| 121 |
+
def process_headlines(headlines: list[str], k_top: int = 5) -> pd.DataFrame:
|
| 122 |
+
"""Zpracuje seznam titulků a vrátí DataFrame s výsledky."""
|
| 123 |
+
if not headlines or all(s.isspace() for s in headlines):
|
| 124 |
+
return pd.DataFrame()
|
| 125 |
+
|
| 126 |
+
disc, tok, clf = load_models()
|
| 127 |
+
df = pd.DataFrame({"Titulek": headlines})
|
| 128 |
+
|
| 129 |
+
# RTD
|
| 130 |
+
rtd_probs_all = rtd_token_scores_batch(headlines, disc, tok, batch_size=32)
|
| 131 |
+
rtd_scores = [clickbait_score_rtd_from_probs(p, k_top=k_top) for p in rtd_probs_all]
|
| 132 |
+
rtd_labels = [rtd_label_from_score(p) for p in rtd_scores]
|
| 133 |
+
|
| 134 |
+
# Supervised
|
| 135 |
+
sup_rows = classify_supervised(headlines, clf)
|
| 136 |
+
df_sup = pd.DataFrame(sup_rows)
|
| 137 |
+
|
| 138 |
+
# Sestavení výsledků
|
| 139 |
+
df_out = df.copy()
|
| 140 |
+
df_out["rtd_score"] = rtd_scores
|
| 141 |
+
df_out["rtd_label"] = rtd_labels
|
| 142 |
+
df_out = pd.concat([df_out, df_sup], axis=1)
|
| 143 |
+
|
| 144 |
+
df_out["combined_score"] = (0.85 * df_out["clf_prob_clickbait"] + 0.15 * df_out["rtd_score"])
|
| 145 |
+
|
| 146 |
+
final_labels = []
|
| 147 |
+
for s in df_out["combined_score"]:
|
| 148 |
+
if s >= COMB_CLICK_TH: final_labels.append("CLICKBAIT")
|
| 149 |
+
elif s <= COMB_NOT_TH: final_labels.append("NOT")
|
| 150 |
+
else: final_labels.append("BORDERLINE")
|
| 151 |
+
df_out["final_label"] = final_labels
|
| 152 |
+
|
| 153 |
+
# Vybereme a přejmenujeme sloupce pro přehlednost
|
| 154 |
+
final_cols = {
|
| 155 |
+
"Titulek": "Titulek",
|
| 156 |
+
"final_label": "Výsledek",
|
| 157 |
+
"combined_score": "Kombinované skóre",
|
| 158 |
+
"clf_prob_clickbait": "Pravděpodobnost clickbaitu",
|
| 159 |
+
"rtd_score": "RTD skóre",
|
| 160 |
+
}
|
| 161 |
+
return df_out[final_cols.keys()].rename(columns=final_cols)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,6 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas==2.2.2
|
| 2 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
| 3 |
+
transformers
|
| 4 |
+
accelerate
|
| 5 |
+
streamlit
|
| 6 |
+
sentencepiece # Požadováno některými tokenizéry
|