Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import base64 | |
| import fasttext | |
| import re | |
| import torch | |
| from transformers import AutoModelForSequenceClassification | |
| from transformers import BertTokenizerFast | |
| st.set_page_config( | |
| page_title="detoxi.ai", | |
| page_icon="./mini_logo1.png", | |
| layout="centered" | |
| ) | |
| # Кодируем логотип в base64 (для локальных файлов) | |
| def get_image_base64(path): | |
| with open(path, "rb") as img_file: | |
| return base64.b64encode(img_file.read()).decode() | |
| # Кэширование модели для ускорения работы | |
| def load_model(): | |
| model = fasttext.load_model('./model_fasttext.bin') | |
| return model | |
| model = load_model() | |
| bin_str = get_image_base64("./билли.png") | |
| page_bg_img = ''' | |
| <style> | |
| .stApp{ | |
| background-image: linear-gradient(rgba(255, 255, 255, 0.7), | |
| rgba(255, 255, 255, 0.7)), | |
| url("data:image/png;base64,%s"); | |
| background-size: cover; | |
| background-position: center; | |
| background-repeat: no-repeat; | |
| background-attachment: fixed; | |
| } | |
| </style> | |
| ''' % bin_str | |
| st.markdown(page_bg_img, unsafe_allow_html=True) | |
| logo_base64 = get_image_base64("./top_logo1.png") | |
| # Используем HTML для вставки логотипа в заголовок | |
| st.markdown( | |
| f""" | |
| <div style="display: flex; justify-content: center;"> | |
| <img src="data:image/png;base64,{logo_base64}" width="400"> | |
| </div> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| # Описание | |
| st.write("""<p style='text-align: center; font-size: 24px;'>Это приложение сделает твою речь менее токсичной. | |
| И даже не придётся платить 300 bucks.</p>""", unsafe_allow_html=True) | |
| class ModelWrapper(object): | |
| MODELS_DIR: str = "./new_models/" | |
| MODEL_NAME: str = "model" | |
| TOKENIZER: str = "tokenizer" | |
| def __init__(self): | |
| self.model = AutoModelForSequenceClassification.from_pretrained( | |
| ModelWrapper.MODELS_DIR + ModelWrapper.MODEL_NAME, torchscript=True | |
| ) | |
| self.tokenizer = BertTokenizerFast.from_pretrained( | |
| ModelWrapper.MODELS_DIR + ModelWrapper.TOKENIZER | |
| ) | |
| self.id2label: dict[int, str] = {0: "__label__positive", 1: "__label__negative"} | |
| def __call__(self, text: str) -> str: | |
| max_input_length = ( | |
| self.model.config.max_position_embeddings | |
| ) # 512 for this model | |
| inputs = self.tokenizer( | |
| text, | |
| max_length=max_input_length, | |
| padding=True, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| outputs = self.model( | |
| **inputs, return_dict=True | |
| ) # output is logits for huggingfcae transformers | |
| predicted = torch.nn.functional.softmax(outputs.logits, dim=1) | |
| predicted_id = torch.argmax(predicted, dim=1).numpy()[0] | |
| return self.id2label[predicted_id] | |
| def highlight_obscene_words(text): | |
| label,_=model.predict(text.lower()) | |
| if label[0]=='__label__positive': | |
| st.markdown( | |
| "<span style='background:#47916B;'>{}|приемлемо</span>".format(text), | |
| unsafe_allow_html=True | |
| ) | |
| else: | |
| st.markdown( | |
| "<span style='background:#ffcccc;'>{}|токсично</span>".format(text), | |
| unsafe_allow_html=True | |
| ) | |
| # Боковая панель | |
| with st.sidebar: | |
| st.header("""О приложении""") | |
| st.write(""" | |
| Это приложение, созданно для сдачи задания по ML. | |
| Оно показывает, чему мы научились за эту домашку: | |
| - Благославлять создателей hugging face | |
| - Писать прототипы приложений с помощью библиотеки Streamlit | |
| - Дружно работать в команде | |
| """, unsafe_allow_html=True) | |
| st.write("""<p style='text-align: center;'>Введите текст ниже, и приложение определит токсичность твоего предложения.</p>""", unsafe_allow_html=True) | |
| user_input = st.text_area('',height=200) | |
| if st.button("Проверить текст"): | |
| if user_input.strip(): | |
| st.subheader("Результат:") | |
| result = re.split(r'[.\n]+', user_input) | |
| result = [part for part in result if part.strip() != ""] | |
| if result!=[]: | |
| for text in result: | |
| highlight_obscene_words(text) | |
| else: | |
| st.warning("Пожалуйста, введите текст для проверки") |