LbejchJakub commited on
Commit
193fd12
·
verified ·
1 Parent(s): e940d52

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +21 -20
  2. app.py +29 -0
  3. model.py +161 -0
  4. requirements.txt +6 -3
Dockerfile CHANGED
@@ -1,20 +1,21 @@
1
- FROM python:3.13.5-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- git \
9
- && rm -rf /var/lib/apt/lists/*
10
-
11
- COPY requirements.txt ./
12
- COPY src/ ./src/
13
-
14
- RUN pip3 install -r requirements.txt
15
-
16
- EXPOSE 8501
17
-
18
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
-
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
1
+ # Použijeme oficiální Python image
2
+ FROM python:3.10-slim
3
+
4
+ # Nastavíme pracovní adresář v kontejneru
5
+ WORKDIR /app
6
+
7
+ # Zkopírujeme soubor se závislostmi
8
+ COPY requirements.txt ./
9
+
10
+ # Nainstalujeme závislosti
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Zkopírujeme zbytek kódu aplikace
14
+ COPY . .
15
+
16
+ # Vystavíme port, na kterém poběží Streamlit
17
+ EXPOSE 8080
18
+
19
+ # Příkaz, který se spustí při startu kontejneru
20
+ # Spustí Streamlit aplikaci na portu 8080
21
+ CMD ["streamlit", "run", "app.py", "--server.port=8080", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================
2
+ # app.py
3
+ # =========================
4
+ import streamlit as st
5
+ import pandas as pd
6
+ from model import process_headlines
7
+
8
+ st.set_page_config(layout="wide")
9
+ st.title("🧪 Detektor clickbaitu")
10
+ st.markdown("Vložte jeden nebo více titulků (každý na nový řádek) a klikněte na 'Analyzovat'.")
11
+
12
+ # Vstupní pole pro text
13
+ input_text = st.text_area("Zadejte titulky:", height=200, placeholder="Např.:\nŠokující odhalení!\nToto neuvěříte!\nBěžná zpráva o počasí.")
14
+
15
+ # Tlačítko pro spuštění analýzy
16
+ if st.button("Analyzovat"):
17
+ if input_text.strip():
18
+ # Rozdělení textu na řádky a odstranění prázdných
19
+ headlines = [line.strip() for line in input_text.split('\n') if line.strip()]
20
+
21
+ with st.spinner("Probíhá analýza... Modely se poprvé stahují, může to trvat i několik minut."):
22
+ try:
23
+ results_df = process_headlines(headlines)
24
+ st.success("Analýza dokončena!")
25
+ st.dataframe(results_df)
26
+ except Exception as e:
27
+ st.error(f"Při analýze došlo k chybě: {e}")
28
+ else:
29
+ st.warning("Zadejte prosím alespoň jeden titulek.")
model.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================
2
+ # model.py
3
+ # =========================
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ from transformers import (
9
+ AutoModelForPreTraining,
10
+ AutoTokenizer,
11
+ pipeline,
12
+ )
13
+ import streamlit as st # Přidáme pro cachování
14
+
15
+ # =========================
16
+ # CONFIG (stejné jako u vás)
17
+ # =========================
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ ELECTRA_MODEL = "Seznam/small-e-czech"
20
+ CLF_MODEL = "Stremie/xlm-roberta-base-clickbait"
21
+
22
+ RTD_CLICKBAIT_TH = 0.20
23
+ RTD_BORDERLINE_TH = 0.15
24
+ CLF_CLICK_TH = 0.65
25
+ CLF_NOT_TH = 0.35
26
+ COMB_CLICK_TH = 0.60
27
+ COMB_NOT_TH = 0.40
28
+
29
+ # =========================
30
+ # LOAD MODELS (s cachováním)
31
+ # =========================
32
+ # Použijeme @st.cache_resource, aby se modely načetly jen jednou
33
+ @st.cache_resource
34
+ def load_models():
35
+ """Načte a vrátí oba modely a tokenizer."""
36
+ print("Načítám modely...")
37
+ disc = AutoModelForPreTraining.from_pretrained(ELECTRA_MODEL).to(DEVICE).eval()
38
+ tok = AutoTokenizer.from_pretrained(ELECTRA_MODEL)
39
+
40
+ clf = pipeline(
41
+ "text-classification",
42
+ model=CLF_MODEL,
43
+ device=0 if DEVICE == "cuda" else -1
44
+ )
45
+
46
+ # ---- Robust label mapping pro klasifikátor ----
47
+ id2label = getattr(clf.model.config, "id2label", {}) or {}
48
+ label_values_upper = {str(v).upper() for v in id2label.values()}
49
+ if not ({"CLICKBAIT", "NOT"} <= label_values_upper):
50
+ clf.model.config.id2label = {0: "NOT", 1: "CLICKBAIT"}
51
+ clf.model.config.label2id = {"NOT": 0, "CLICKBAIT": 1}
52
+
53
+ print("Modely načteny.")
54
+ return disc, tok, clf
55
+
56
+ # Všechny vaše ostatní funkce (rtd_token_scores_batch, classify_supervised, atd.)
57
+ # zde zkopírujte BEZE ZMĚN.
58
+ # ... (vložte sem zbytek funkcí z vašeho skriptu) ...
59
+ @torch.no_grad()
60
+ def rtd_token_scores_batch(texts, disc, tok, batch_size=32):
61
+ all_scores = []
62
+ for i in range(0, len(texts), batch_size):
63
+ enc = tok(texts[i:i+batch_size], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
64
+ out = disc(**enc)
65
+ probs = torch.sigmoid(out.logits).detach().cpu().numpy()
66
+ all_scores.extend(probs)
67
+ return all_scores
68
+
69
+ def clickbait_score_rtd_from_probs(probs, k_top: int = 5) -> float:
70
+ core = probs[1:-1] if len(probs) >= 2 else probs
71
+ if core.size == 0: return 0.0
72
+ k = min(k_top, core.size)
73
+ topk = np.partition(core, -k)[-k:]
74
+ score = float(np.mean(topk))
75
+ return max(0.0, min(1.0, score))
76
+
77
+ def rtd_label_from_score(p: float) -> str:
78
+ if p >= RTD_CLICKBAIT_TH: return "CLICKBAIT"
79
+ if p >= RTD_BORDERLINE_TH: return "BORDERLINE"
80
+ return "NOT"
81
+
82
+ def _normalize_label_to_index(lbl, LABEL2ID):
83
+ if isinstance(lbl, int): return lbl
84
+ s = str(lbl)
85
+ if s in LABEL2ID: return LABEL2ID[s]
86
+ m = re.search(r"(\d+)$", s)
87
+ if m: return int(m.group(1))
88
+ return None
89
+
90
+ def classify_supervised(texts, clf):
91
+ ID2LABEL = clf.model.config.id2label
92
+ LABEL2ID = clf.model.config.label2id
93
+ sanitized = [str(t).strip() if pd.notna(t) else "" for t in texts]
94
+ outs = clf(sanitized, top_k=None, truncation=True, max_length=256)
95
+ results = []
96
+ for scores in outs:
97
+ prob_click, prob_not = 0.0, 0.0
98
+ for s in scores:
99
+ idx = _normalize_label_to_index(s["label"], LABEL2ID)
100
+ if idx is None: continue
101
+ name = ID2LABEL.get(idx, str(s["label"])).upper()
102
+ if name == "CLICKBAIT": prob_click = float(s["score"])
103
+ elif name == "NOT": prob_not = float(s["score"])
104
+
105
+ binary_label = "CLICKBAIT" if prob_click >= prob_not else "NOT"
106
+ if prob_click >= CLF_CLICK_TH: tri_label = "CLICKBAIT"
107
+ elif prob_click <= CLF_NOT_TH: tri_label = "NOT"
108
+ else: tri_label = "BORDERLINE"
109
+ clf_margin = abs(prob_click - prob_not)
110
+ results.append({
111
+ "clf_prob_clickbait": prob_click, "clf_prob_not": prob_not,
112
+ "clf_label": binary_label, "clf_label_3way": tri_label,
113
+ "clf_margin": clf_margin,
114
+ })
115
+ return results
116
+
117
+
118
+ # =========================
119
+ # HLAVNÍ FUNKCE PRO ZPRACOVÁNÍ
120
+ # =========================
121
+ def process_headlines(headlines: list[str], k_top: int = 5) -> pd.DataFrame:
122
+ """Zpracuje seznam titulků a vrátí DataFrame s výsledky."""
123
+ if not headlines or all(s.isspace() for s in headlines):
124
+ return pd.DataFrame()
125
+
126
+ disc, tok, clf = load_models()
127
+ df = pd.DataFrame({"Titulek": headlines})
128
+
129
+ # RTD
130
+ rtd_probs_all = rtd_token_scores_batch(headlines, disc, tok, batch_size=32)
131
+ rtd_scores = [clickbait_score_rtd_from_probs(p, k_top=k_top) for p in rtd_probs_all]
132
+ rtd_labels = [rtd_label_from_score(p) for p in rtd_scores]
133
+
134
+ # Supervised
135
+ sup_rows = classify_supervised(headlines, clf)
136
+ df_sup = pd.DataFrame(sup_rows)
137
+
138
+ # Sestavení výsledků
139
+ df_out = df.copy()
140
+ df_out["rtd_score"] = rtd_scores
141
+ df_out["rtd_label"] = rtd_labels
142
+ df_out = pd.concat([df_out, df_sup], axis=1)
143
+
144
+ df_out["combined_score"] = (0.85 * df_out["clf_prob_clickbait"] + 0.15 * df_out["rtd_score"])
145
+
146
+ final_labels = []
147
+ for s in df_out["combined_score"]:
148
+ if s >= COMB_CLICK_TH: final_labels.append("CLICKBAIT")
149
+ elif s <= COMB_NOT_TH: final_labels.append("NOT")
150
+ else: final_labels.append("BORDERLINE")
151
+ df_out["final_label"] = final_labels
152
+
153
+ # Vybereme a přejmenujeme sloupce pro přehlednost
154
+ final_cols = {
155
+ "Titulek": "Titulek",
156
+ "final_label": "Výsledek",
157
+ "combined_score": "Kombinované skóre",
158
+ "clf_prob_clickbait": "Pravděpodobnost clickbaitu",
159
+ "rtd_score": "RTD skóre",
160
+ }
161
+ return df_out[final_cols.keys()].rename(columns=final_cols)
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
1
+ pandas==2.2.2
2
+ torch --index-url https://download.pytorch.org/whl/cpu
3
+ transformers
4
+ accelerate
5
+ streamlit
6
+ sentencepiece # Požadováno některými tokenizéry