# Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/


def _model_device(model):
    """Return the device the model lives on, falling back to the CPU."""
    device = getattr(model, "device", None)
    if device is not None:
        return device
    # Plain ``torch.nn.Module`` objects expose no ``device`` attribute, so
    # look at the first parameter instead.
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first_parameter = next(iter(parameters()), None)
        if first_parameter is not None:
            return first_parameter.device
    return torch.device("cpu")


def predict_binary_outcomes(model, tokenizer, text, threshold=0.3):
    """Predict the validity probability of the text in each dialect.

    A sigmoid activation is applied independently to each dialect's logit and
    the resulting probabilities are returned as they are. No thresholding is
    done here: deciding which probabilities count as "valid" is left to the
    caller.

    The model is expected to generate logits for each of the following
    dialects, in this order: Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait,
    Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan,
    Syria, Tunisia, UAE, Yemen.

    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy
    from MBZUAI.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` tensor.
        tokenizer: Callable that turns ``text`` into a mapping holding the
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        text: The text to score.
        threshold: Currently unused; kept so that existing callers passing it
            keep working.

    Returns:
        A flat list of floats: the sigmoid of every logit the model produced.
    """
    encodings = tokenizer(
        text, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )

    # Send the inputs to wherever the model actually is. Picking "cuda" just
    # because a GPU exists breaks when the model itself was left on the CPU.
    device = _model_device(model)
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    return torch.sigmoid(outputs.logits).reshape(-1).tolist()


def preprocess_text(arabic_text):
    """Strip URLs and Latin letters from the given Arabic text.

    URLs are removed first, then every remaining ASCII letter. Digits,
    punctuation and whitespace are left untouched, so the result may contain
    consecutive or trailing spaces where text was removed.

    Args:
        arabic_text: The Arabic text to be preprocessed.

    Returns:
        The preprocessed Arabic text.
    """
    no_urls = re.sub(
        r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b",
        "",
        arabic_text,
        flags=re.MULTILINE,
    )
    no_english = re.sub(r"[a-zA-Z]", "", no_urls)
    return no_english