import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import classification_report, f1_score, accuracy_score, hamming_loss
import gradio as gr
# Load dataset
splits = {'train': 'simplified/train-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/google-research-datasets/go_emotions/" + splits["train"])

emotion_labels = [
    "admiration", "amusement", "anger", "annoyance", "approval",
    "caring", "confusion", "curiosity", "desire", "disappointment",
    "disapproval", "disgust", "embarrassment", "excitement", "fear",
    "gratitude", "grief", "joy", "love", "nervousness",
    "optimism", "pride", "realization", "relief", "remorse",
    "sadness", "surprise", "neutral"
]
index_to_emotion = {i: label for i, label in enumerate(emotion_labels)}

# Binarize the per-comment lists of label ids into a multi-hot indicator matrix
mlb = MultiLabelBinarizer(classes=range(28))
y = mlb.fit_transform(df['labels'])

# TF-IDF features over the raw comment text
vectorizer = TfidfVectorizer(max_features=5000)
X = vectorizer.fit_transform(df['text'])
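# Quick shape sanity check (illustrative only; the app does not depend on it).
# X should be a sparse (n_samples, 5000) TF-IDF matrix and y a dense
# (n_samples, 28) multi-hot label matrix.
# print(X.shape, y.shape)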
# Placeholder for trained model
model = None
metrics_report = ""
def train_model(test_size=0.2, max_iter=1000, random_state=42):
    global model, metrics_report
    # Gradio sliders/number inputs deliver floats; scikit-learn expects integers here
    max_iter = int(max_iter)
    random_state = int(random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # One binary logistic regression per emotion label
    model = MultiOutputClassifier(LogisticRegression(max_iter=max_iter))
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    # Calculate standard classification report + other multi-label metrics
    report = classification_report(
        y_test, y_pred, target_names=emotion_labels, zero_division=0
    )
    micro_f1 = f1_score(y_test, y_pred, average="micro")
    macro_f1 = f1_score(y_test, y_pred, average="macro")
    acc = accuracy_score(y_test, y_pred)      # exact match: all 28 labels must be correct
    hamming = hamming_loss(y_test, y_pred)    # fraction of individual label decisions that are wrong

    metrics_summary = f"""
Micro F1-score: {micro_f1:.4f}
Macro F1-score: {macro_f1:.4f}
Accuracy (Exact Match): {acc:.4f}
Hamming Loss: {hamming:.4f}
"""
    # Save the full report for the Results tab
    metrics_report = metrics_summary.strip() + "\n\n" + report
    return "Training Complete!"
def predict_emotions(text):
    if model is None:
        # Return a plain string so callers can detect the "not trained" case
        return "Please train the model first."
    vectorized = vectorizer.transform([text])
    # MultiOutputClassifier.predict_proba returns one (n_samples, 2) array per label
    probas = model.predict_proba(vectorized)
    result = {}
    for i, label_id in enumerate(mlb.classes_):
        prob_positive = probas[i][0][1]  # probability that this label applies
        result[label_id] = round(prob_positive * 100, 2)
    # Sort label ids by predicted probability, highest first
    sorted_result = sorted(result.items(), key=lambda x: x[1], reverse=True)
    return sorted_result
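# Usage sketch (commented out so the Gradio app is unaffected): the two functions
# above can be driven directly, e.g. from a notebook; the sample sentence is illustrative.
# train_model(test_size=0.2, max_iter=1000, random_state=42)
# print(predict_emotions("Thank you so much, this made my day!")[:3])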
def predict_and_display(sentence):
    predictions = predict_emotions(sentence)
    if isinstance(predictions, str):
        return predictions, ""
    max_len = max(len(index_to_emotion[emo_id]) for emo_id, _ in predictions)
    result = "```\nEmotion Predictions:\n\n"
    for emo_id, score in predictions:
        emo_name = index_to_emotion[emo_id]
        result += f"{emo_name.ljust(max_len)} → {score}%\n"
    result += "```"
    top_emotion = index_to_emotion[predictions[0][0]]
    return result, top_emotion
# Gradio App
with gr.Blocks(title="Interactive Emotion Detector", theme=gr.themes.Soft()) as demo:
    with gr.Tabs():
        with gr.Tab("Emotion Detection"):
            gr.Markdown("## Emotion Detection")
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(
                        lines=3, placeholder="Enter a sentence...", label="Input Sentence"
                    )
                    submit_btn = gr.Button("Analyze Emotion")
                with gr.Column():
                    output_text = gr.Markdown(label="Prediction Results")
                    top_emotion = gr.Label(label="Top Emotion")
            submit_btn.click(
                fn=predict_and_display,
                inputs=input_text,
                outputs=[output_text, top_emotion]
            )
| with gr.Tab("Dataset"): | |
| gr.Markdown("## Dataset Information") | |
| def dataset_info(): | |
| df = pd.read_parquet("hf://datasets/google-research-datasets/go_emotions/simplified/train-00000-of-00001.parquet") | |
| total_samples = len(df) | |
| emotions = sorted(set(e for label in df['labels'] for e in label)) | |
| emotion_names = [emotion_labels[i] for i in emotions] | |
| # Count distribution | |
| all_labels = [emotion_labels[i] for sublist in df['labels'] for i in sublist] | |
| label_counts = pd.Series(all_labels).value_counts().sort_index() | |
| label_df = pd.DataFrame({ | |
| "Emotion": label_counts.index, | |
| "Count": label_counts.values | |
| }) | |
| stats = f""" | |
| **Total Samples**: {total_samples} | |
| **Emotion Classes**: {', '.join(emotion_names)} | |
| """ | |
| return stats, label_df | |
| stats_display = gr.Markdown() | |
| dist_table = gr.Dataframe(headers=["Emotion", "Count"], interactive=False) | |
| load_btn = gr.Button("Load Dataset Info") | |
| load_btn.click(fn=dataset_info, inputs=[], outputs=[stats_display, dist_table]) | |
| with gr.Tab("EDA"): | |
| gr.Markdown("## Exploratory Data Analysis") | |
| eda_btn = gr.Button("Run EDA") | |
| eda_output = gr.Plot(label="EDA Output") | |
| def run_eda(): | |
| import matplotlib.pyplot as plt | |
| from collections import Counter | |
| import re | |
| # Define the label map inside the function | |
| label_map = [ | |
| 'admiration', 'amusement', 'anger', 'annoyance', 'approval', | |
| 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', | |
| 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', | |
| 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', | |
| 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', | |
| 'neutral' | |
| ] | |
| fig, axs = plt.subplots(2, 2, figsize=(18, 10)) | |
| # Label distribution | |
| label_counts = df['labels'].explode().value_counts().sort_index() | |
| axs[0, 0].bar(label_map, label_counts) | |
| axs[0, 0].set_title("Label Frequency Distribution") | |
| axs[0, 0].tick_params(axis='x', rotation=45) | |
| # Labels per example | |
| df['num_labels'] = df['labels'].apply(len) | |
| df['num_labels'].value_counts().sort_index().plot(kind='bar', ax=axs[0, 1]) | |
| axs[0, 1].set_title("Number of Labels per Example") | |
| # Text length distribution | |
| df['text_length'] = df['text'].apply(len) | |
| df['text_length'].hist(bins=50, ax=axs[1, 0]) | |
| axs[1, 0].set_title("Distribution of Text Lengths") | |
| axs[1, 0].set_xlabel("Text Length (characters)") | |
| axs[1, 0].set_ylabel("Frequency") | |
| # Most common words | |
| all_words = " ".join(df['text']).lower() | |
| tokens = re.findall(r'\b\w+\b', all_words) | |
| common_words = Counter(tokens).most_common(20) | |
| words, freqs = zip(*common_words) | |
| axs[1, 1].bar(words, freqs) | |
| axs[1, 1].set_title("Top 20 Most Common Words") | |
| axs[1, 1].tick_params(axis='x', rotation=45) | |
| plt.tight_layout() | |
| return fig | |
| eda_btn.click(fn=run_eda, inputs=[], outputs=eda_output) | |
| with gr.Tab("Train Model"): | |
| gr.Markdown("## Train Your Emotion Model") | |
| test_size = gr.Slider(0.1, 0.5, step=0.05, value=0.2, label="Test Size") | |
| max_iter = gr.Slider(100, 5000, step=100, value=1000, label="Max Iterations") | |
| random_state = gr.Number(value=42, label="Random State") | |
| train_button = gr.Button("Train Model") | |
| train_status = gr.Textbox(label="Training Status") | |
| train_button.click( | |
| fn=train_model, | |
| inputs=[test_size, max_iter, random_state], | |
| outputs=train_status | |
| ) | |
| with gr.Tab("Results"): | |
| gr.Markdown("## Evaluation Metrics") | |
| results_output = gr.Markdown(label="Classification Report") | |
| def get_report(): | |
| return "```\n" + metrics_report + "\n```" | |
| refresh_btn = gr.Button("Refresh Report") | |
| refresh_btn.click( | |
| fn=get_report, | |
| inputs=[], | |
| outputs=results_output | |
| ) | |
| demo.launch() | |
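# To run locally (assumed setup, not pinned by this file): save as app.py and install
# roughly `gradio pandas scikit-learn matplotlib pyarrow huggingface_hub fsspec`,
# then run `python app.py`. The huggingface_hub and fsspec packages are what let
# pandas read the "hf://" dataset path above.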