Spaces:
Sleeping
Sleeping
| import os | |
| from datetime import datetime | |
| from pathlib import Path | |
| import huggingface_hub | |
| import jiwer | |
| import pandas as pd | |
| import requests | |
| import streamlit as st | |
| from huggingface_hub import HfFileSystem | |
| from manual_evlaution import render_manual_eval | |
| from st_fixed_container import st_fixed_container | |
| from substitutions_visualizer import visualize_substitutions | |
| from utils import display_rtl, ltr_tag | |
| from visual_eval.evaluator import HebrewTextNormalizer | |
| from visual_eval.visualization import render_visualize_jiwer_result_html | |
# Resolve the Hugging Face API token: Streamlit secrets first, then the
# HF_API_TOKEN environment variable.
HF_API_TOKEN = None
try:
    HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
except (FileNotFoundError, KeyError):
    # FileNotFoundError: no secrets file is available.
    # KeyError: a secrets file exists but does not define HF_API_TOKEN.
    # In both cases fall back to the process environment.
    HF_API_TOKEN = os.environ.get("HF_API_TOKEN")

# Audio previews and leaderboard browsing are gated on this flag.
has_api_token = HF_API_TOKEN is not None
# Datasets this app recognizes, one tuple per dataset:
#   (dataset spec, dataset config or None, output name)
# The dataset spec is colon separated: "<repo id>:<split>:<third part>".
# Only the repo id and split are read in this file; the third part is
# presumably the transcript column name (not used here - confirm upstream).
# The output name is what get_known_dataset_by_output_name() matches against.
known_datasets = [
    (
        "ivrit-ai/eval-d1:test:text",
        None,
        "ivrit_ai_eval_d1",
    ),
    (
        "upai-inc/saspeech:test:text",
        None,
        "saspeech",
    ),
    (
        "google/fleurs:test:transcription",
        "he_il",
        "fleurs",
    ),
    (
        "mozilla-foundation/common_voice_17_0:test:sentence",
        "he",
        "common_voice_17",
    ),
    (
        "imvladikon/hebrew_speech_kan:validation:sentence",
        None,
        "hebrew_speech_kan",
    ),
]
# Initialize every session-state key this app reads, if it doesn't exist yet.
# Each key is guarded by its OWN name. Previously `results_file` was guarded
# by the unrelated "uploaded_file" key, which is the file-uploader widget key.
_SESSION_STATE_DEFAULTS = {
    "audio_cache": {},  # entry index -> {"url": ..., "expires": ...}
    "audio_preview_active": {},  # entry index -> bool
    "manual_mode": False,
    "results_file": None,  # uploaded file object or leaderboard CSV URL
    "consumed_query_lb_file": None,
    "selected_entry_idx": 0,
    "total_entry_count": 0,
    "entry_page_size": 20,
}
for _state_key, _state_default in _SESSION_STATE_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def get_current_page_slice():
    """Return the slice of entry indices for the page holding the selected entry.

    Reads `total_entry_count`, `entry_page_size` and `selected_entry_idx`
    from the Streamlit session state.

    Returns:
        `slice(0, 0)` when there are no entries, otherwise a slice starting at
        the first index of the current page and ending (exclusive) at the end
        of that page, clamped to the total entry count.
    """
    state = st.session_state
    total = state.total_entry_count
    if total == 0:
        return slice(0, 0)

    page_size = state.entry_page_size
    page_number = state.selected_entry_idx // page_size
    first_idx = page_number * page_size
    end_idx = min(first_idx + page_size, total)
    return slice(first_idx, end_idx)
def page_navigation():
    """Render "Prev Page" / "Next Page" buttons and move the selection on click.

    Clicking a button writes a new `selected_entry_idx` into the session state
    and reruns the script: "Prev Page" selects the last entry of the previous
    page, "Next Page" selects the first entry of the following page.
    """
    state = st.session_state
    page = get_current_page_slice()

    # page.stop is exclusive, so a following page exists as soon as at least
    # one entry index is >= page.stop. (The previous `< total - 1` check hid
    # the button when exactly one entry was left on the next page.)
    has_next_page = page.stop < state.total_entry_count
    has_prev_page = page.start >= state.entry_page_size

    prev_col, next_col = st.columns(2)
    if prev_col.button("Prev Page", disabled=not has_prev_page):
        state.selected_entry_idx = page.start - 1
        st.rerun()
    if next_col.button("Next Page", disabled=not has_next_page):
        state.selected_entry_idx = page.start + state.entry_page_size
        st.rerun()
def on_file_upload():
    """Reset all per-file session state.

    Clears the results file, moves the selection back to the first entry and
    empties the audio URL cache and the audio preview flags.
    """
    state = st.session_state
    state.results_file = None
    state.selected_entry_idx = 0
    state.audio_preview_active = {}
    state.audio_cache = {}
def reset_upload_state():
    """Forget any leaderboard file selection and reset the per-file state.

    Removes the `lb_result_file` query parameter when present, clears the
    remembered leaderboard file and then delegates to `on_file_upload()`.
    """
    query_params = st.query_params
    if "lb_result_file" in query_params:
        query_params.pop("lb_result_file")

    st.session_state.consumed_query_lb_file = None
    on_file_upload()
def get_leaderboard_result_csv_paths(root_search_path):
    """List CSV files found one directory level below `root_search_path`.

    Args:
        root_search_path: Path understood by `HfFileSystem` to search under.

    Returns:
        A list with, for each match of `<root>/*/*.csv`, the part of the path
        that follows `root_search_path`.
    """
    hf_fs = HfFileSystem(token=HF_API_TOKEN)
    csv_pattern = f"{root_search_path}/*/*.csv"

    relative_paths = []
    for full_path in hf_fs.glob(csv_pattern):
        # Keep only what comes after the search root.
        relative_paths.append(full_path.split(root_search_path)[1])
    return relative_paths
def choose_input_file_from_leaderboard():
    """Offer the leaderboard result CSVs in a select box and load the chosen one.

    When a file is chosen, its download URL is stored as the session's results
    file, the choice is mirrored into the `lb_result_file` query parameter, and
    the script is rerun. Without an API token the script is simply rerun.
    """
    # NOTE(review): the caller invokes this from inside a button click; confirm
    # that the select box survives the rerun triggered by changing its value.
    if not has_api_token:
        st.rerun()

    root_search_path = "ivrit-ai/hebrew-transcription-leaderboard/results"
    csv_relative_paths = get_leaderboard_result_csv_paths(
        f"spaces/{root_search_path}"
    )

    selected_file = st.selectbox(
        "Select a CSV file from the leaderboard:",
        csv_relative_paths,
        index=None,
    )
    if not selected_file:
        return

    # Split the relative path into its folder and file name for hf_hub_url.
    selected_path = Path(selected_file)
    results_url = huggingface_hub.hf_hub_url(
        repo_id="ivrit-ai/hebrew-transcription-leaderboard",
        subfolder=f"results{selected_path.parent}",
        filename=selected_path.name,
        repo_type="space",
    )

    on_file_upload()
    state = st.session_state
    state.results_file = results_url
    st.query_params["lb_result_file"] = selected_file
    state.consumed_query_lb_file = selected_file
    st.rerun()
def read_results_csv(uploaded_file):
    """Read the results CSV into a DataFrame while showing a spinner.

    Args:
        uploaded_file: Anything `pandas.read_csv` accepts; callers in this
            file pass an uploaded file object or a URL string.

    Returns:
        The parsed `pandas.DataFrame`.
    """
    with st.spinner("Loading results...", show_time=True):
        return pd.read_csv(uploaded_file)
def calculate_final_metrics(uploaded_file, _df):
    """Compute corpus-level word metrics over every entry of the results.

    Rows are ordered by `id`, missing texts are treated as empty strings, and
    both texts are passed through `HebrewTextNormalizer` before scoring.

    Args:
        uploaded_file: Not read by the body; kept for cache-hash generation
            per the original contract.
        _df: DataFrame with `id`, `reference_text` and `predicted_text`
            columns (not included in the cache hash).

    Returns:
        The result object of `jiwer.process_words` for all entries.
    """
    ordered = _df.sort_values(by=["id"])
    normalize = HebrewTextNormalizer()

    reference_texts = ordered["reference_text"].fillna("").tolist()
    predicted_texts = ordered["predicted_text"].fillna("").tolist()

    normalized_references = [normalize(text) for text in reference_texts]
    normalized_predictions = [normalize(text) for text in predicted_texts]

    return jiwer.process_words(normalized_references, normalized_predictions)
def get_known_dataset_by_output_name(output_name):
    """Return the first `known_datasets` tuple whose output name matches.

    Args:
        output_name: Value compared against the third field of each tuple.

    Returns:
        The matching dataset tuple, or None when nothing matches.
    """
    matches = (entry for entry in known_datasets if entry[2] == output_name)
    return next(matches, None)
def get_dataset_entries_audio_urls(dataset, offset=0, max_entries=100, *, timeout=30):
    """Fetch audio URLs for a range of dataset rows from the datasets server.

    Args:
        dataset: Dataset tuple `(spec, config, output_name)` where `spec` is
            `"<repo id>[:<split>[:...]]"`, or None.
        offset: Index of the first row to request.
        max_entries: Number of rows to request.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list with the `src` of the first audio item of each returned row, or
        None when there is no dataset, no API token, the request fails, the
        response is not valid JSON, or no rows come back.
    """
    if dataset is None or not has_api_token:
        return None

    dataset_spec, dataset_config, _ = dataset
    if not dataset_config:
        dataset_config = "default"

    # Only the repo id and the split are needed; the split defaults to "test"
    # when the spec does not carry one.
    dataset_repo_id, _, spec_remainder = dataset_spec.partition(":")
    split = spec_remainder.partition(":")[0] or "test"

    # The endpoint URL was corrupted in this copy of the file; this is the
    # datasets-server "rows" endpoint the corrupted literal decodes to.
    rows_endpoint = "https://datasets-server.huggingface.co/rows"
    api_query_params = {
        "dataset": dataset_repo_id,
        "config": dataset_config,
        "split": split,
        "offset": offset,
        "length": max_entries,
    }
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    try:
        # `params=` lets requests URL-encode the values (repo ids contain "/").
        response = requests.get(
            rows_endpoint,
            params=api_query_params,
            headers=headers,
            timeout=timeout,
        )
        data = response.json()
    except (requests.RequestException, ValueError):
        # Network failure, timeout or a non-JSON body: report "no audio".
        return None

    rows = data.get("rows") if isinstance(data, dict) else None
    if not rows:
        return None

    # Each row's "audio" feature is a list; use the first item's source URL.
    return [row["row"]["audio"][0]["src"] for row in rows]
def get_audio_url_for_entry(
    dataset, entry_idx, cache_neighbors=True, neighbor_range=20
):
    """Fetch the audio URL of one entry, caching every URL that was fetched.

    Args:
        dataset: Dataset tuple (repo_id, config, output_name)
        entry_idx: Index of the entry whose audio URL is wanted
        cache_neighbors: When True, also fetch and cache the entries within
            `neighbor_range` on either side of `entry_idx`
        neighbor_range: How many entries before and after to include

    Returns:
        The audio URL of the requested entry, or None when nothing could be
        fetched or the entry falls outside the fetched window
    """
    # Decide which window of entries to request.
    if cache_neighbors:
        window_start = max(0, entry_idx - neighbor_range)
        window_size = neighbor_range * 2 + 1
    else:
        window_start, window_size = entry_idx, 1

    fetched_urls = get_dataset_entries_audio_urls(
        dataset, window_start, window_size
    )
    if not fetched_urls:
        return None

    # Store every fetched URL together with its expiry, when one is encoded
    # in the URL as an `expires=<unix timestamp>` parameter.
    cache = st.session_state.audio_cache
    for absolute_idx, url in enumerate(fetched_urls, start=window_start):
        expiry = None
        if "expires=" in url:
            try:
                raw_expiry = url.split("expires=")[1].split("&")[0]
                expiry = datetime.fromtimestamp(int(raw_expiry))
            except (ValueError, IndexError):
                expiry = None
        cache[absolute_idx] = {"url": url, "expires": expiry}

    # Map the requested entry back into the fetched window.
    position = entry_idx - window_start
    if 0 <= position < len(fetched_urls):
        return fetched_urls[position]
    return None
def get_cached_audio_url(entry_idx):
    """Look up an entry's audio URL in the session cache.

    Args:
        entry_idx: Index of the entry to look up

    Returns:
        The cached URL, or None when the entry is not cached or its recorded
        expiry time has already passed
    """
    try:
        cached = st.session_state.audio_cache[entry_idx]
    except KeyError:
        return None

    expiry = cached["expires"]
    # Entries without a known expiry are treated as still valid.
    if expiry and datetime.now() > expiry:
        return None
    return cached["url"]
def main():
    """Render the ASR evaluation visualizer page.

    Flow, top to bottom:
      1. Preload a leaderboard results file named by the `lb_result_file`
         query parameter (once per distinct value).
      2. Offer a CSV upload, a leaderboard picker and a manual-entry mode.
      3. When a results file is set: load it, render the sidebar (total
         metrics, toggles, entry navigation and selection) and the main view
         (optional audio preview, reference/hypothesis visualization,
         optional metadata and substitutions list).
      4. Otherwise show the manual-entry UI or the expected CSV format.

    Any exception raised while processing the results file is reported with
    `st.error` instead of propagating.
    """
    st.set_page_config(
        page_title="ASR Evaluation Visualizer", page_icon="🎤", layout="wide"
    )

    # RTL some tags
    ltr_tag("textarea")

    # Check for URL parameter for preloading leaderboard results
    lb_result_file_param = st.query_params.get("lb_result_file")
    if (
        lb_result_file_param
        and st.session_state.consumed_query_lb_file != lb_result_file_param
    ):
        st.session_state.consumed_query_lb_file = lb_result_file_param
        leaderboard_file_url = huggingface_hub.hf_hub_url(
            repo_id="ivrit-ai/hebrew-transcription-leaderboard",
            subfolder=f"results{Path(lb_result_file_param).parent}",
            filename=Path(lb_result_file_param).name,
            repo_type="space",
        )
        on_file_upload()
        st.session_state.results_file = leaderboard_file_url

    if not has_api_token:
        st.warning("No Hugging Face API token found. Audio previews will not work.")

    st.title("ASR Evaluation Visualizer")

    # File uploader
    uploaded_file = st.file_uploader(
        "Upload evaluation results CSV",
        type=["csv"],
        on_change=reset_upload_state,
        key="uploaded_file",
    )

    if st.session_state.consumed_query_lb_file is not None:
        clear_col1, clear_col2 = st.columns([10, 1], gap="large")
        with clear_col1:
            st.info(
                f"Loaded: {st.session_state.consumed_query_lb_file or uploaded_file}"
            )
        with clear_col2:
            if st.button("Unload"):
                reset_upload_state()
                st.rerun()

    if uploaded_file is not None:
        st.session_state.results_file = uploaded_file

    if st.session_state.results_file is None:
        st.write("Or:")
        if st.button("Choose from leaderboard"):
            choose_input_file_from_leaderboard()
        if st.button("Enter Manually"):
            st.session_state.manual_mode = True
            reset_upload_state()
            st.rerun()

    if st.session_state.results_file is not None:
        uploaded_file = st.session_state.results_file

        # Load the data
        try:
            eval_results = read_results_csv(uploaded_file)
            st.session_state.total_entry_count = len(eval_results)

            with st.sidebar:
                # Toggle for calculating total metrics
                show_total_metrics = st.toggle("Show total metrics", value=False)
                if show_total_metrics:
                    total_metrics = calculate_final_metrics(uploaded_file, eval_results)
                    # Display total metrics in a nice format
                    with st.container():
                        st.metric("WER", f"{total_metrics.wer * 100:.4f}%")
                        st.table(
                            {
                                "Hits": total_metrics.hits,
                                "Subs": total_metrics.substitutions,
                                "Dels": total_metrics.deletions,
                                "Insrt": total_metrics.insertions,
                            }
                        )

            # Toggle for normalized vs raw text
            use_normalized = st.sidebar.toggle("Use normalized text", value=True)
            show_metadata = st.sidebar.toggle("Show entry metadata", value=False)
            visualize_subs = st.sidebar.toggle("List Substitutions", value=False)

            # Create sidebar for entry selection
            st.sidebar.header("Select Entry")

            # Add Next/Prev buttons at the top of the sidebar
            col1, col2 = st.sidebar.columns(2)

            # Navigation callbacks; they run before the rerun that follows a
            # click, so the radio below picks up the new index.
            def go_prev():
                if st.session_state.selected_entry_idx > 0:
                    st.session_state.selected_entry_idx -= 1

            def go_next():
                if st.session_state.selected_entry_idx < len(eval_results) - 1:
                    st.session_state.selected_entry_idx += 1

            # Add navigation buttons
            col1.button("← Prev", on_click=go_prev, use_container_width=True)
            col2.button("Next →", on_click=go_next, use_container_width=True)

            # Use a container for better styling
            entry_container = st.sidebar.container()
            with entry_container:
                page_navigation()
                st.write(f"Total entries: {st.session_state.total_entry_count}")

            # Build one WER label per entry. The column is read once instead
            # of materializing a row Series per entry; a missing "wer" column
            # falls back to 0 for every entry, as before.
            if "wer" in eval_results.columns:
                wer_values = eval_results["wer"].tolist()
            else:
                wer_values = [0] * len(eval_results)
            wer_labels = [
                f"{wer_value*100:.2f}%"
                if isinstance(wer_value, (int, float))
                else wer_value
                for wer_value in wer_values
            ]

            # Create a selection mechanism using radio buttons that look like a table
            st.sidebar.write("Select an entry")

            # One radio option per entry on the current page; the widget is
            # bound to the same session key the navigation buttons update.
            current_page_slice = get_current_page_slice()
            entry_container.radio(
                "Select an entry",
                options=list(range(len(eval_results))[current_page_slice]),
                format_func=lambda i: f"Entry #{i+1} ({wer_labels[i]})",
                label_visibility="collapsed",
                key="selected_entry_idx",
            )

            # Use the selected entry
            selected_entry = st.session_state.selected_entry_idx

            # Get the text columns based on the toggle
            if use_normalized and "norm_reference_text" in eval_results.columns:
                ref_col, hyp_col = "norm_reference_text", "norm_predicted_text"
            else:
                ref_col, hyp_col = "reference_text", "predicted_text"

            # Get the reference and hypothesis texts
            ref, hyp = eval_results.iloc[selected_entry][[ref_col, hyp_col]].values

            st.header("Visualization")

            # Work out which known dataset (if any) the results belong to:
            # first from the file name, then from the "dataset" column.
            dataset_name = None
            if uploaded_file is not None:
                if isinstance(uploaded_file, str):
                    filename_stem = Path(uploaded_file).stem
                else:
                    filename_stem = Path(uploaded_file.name).stem
                dataset_name = filename_stem

            if not dataset_name and "dataset" in eval_results.columns:
                dataset_name = eval_results.iloc[selected_entry]["dataset"]

            # Get the known dataset if available
            known_dataset = get_known_dataset_by_output_name(dataset_name)

            # Display audio preview button if from a known dataset
            if known_dataset:
                # Check if we have the audio URL in cache
                audio_url = get_cached_audio_url(selected_entry)

                audio_preview_active = st.session_state.audio_preview_active.get(
                    selected_entry, False
                )

                # NOTE(review): when the preview was activated earlier but the
                # cached URL has expired, neither the button nor the player is
                # shown for this entry - confirm whether that is intended.
                preview_audio = False
                if not audio_preview_active:
                    # Create a button to preview audio
                    preview_audio = st.button("Preview Audio", key="preview_audio")

                if preview_audio or audio_url:
                    st.session_state.audio_preview_active[selected_entry] = True
                    with st_fixed_container(
                        mode="sticky", position="top", border=True, margin=0
                    ):
                        # If button clicked or we already have the URL, get/use the audio URL
                        if not audio_url:
                            with st.spinner("Loading audio..."):
                                audio_url = get_audio_url_for_entry(
                                    known_dataset, selected_entry
                                )

                        # Display the audio player in the sticky container at the top
                        if audio_url:
                            st.audio(audio_url)
                        else:
                            st.error("Failed to load audio for this entry.")

            # Display the visualization
            html = render_visualize_jiwer_result_html(ref, hyp)
            display_rtl(html)

            if show_metadata:
                # Display metadata
                st.header("Metadata")
                metadata_cols = [
                    "metadata_uuid",
                    "model",
                    "dataset",
                    "dataset_split",
                    "engine",
                ]
                # Only show the metadata columns the CSV actually has, so a
                # file lacking some of them no longer aborts the whole page.
                present_cols = [
                    col for col in metadata_cols if col in eval_results.columns
                ]
                metadata = eval_results.iloc[selected_entry][present_cols]

                # Create a DataFrame for better display
                metadata_df = pd.DataFrame(
                    {"Field": present_cols, "Value": [str(v) for v in metadata.values]}
                )
                st.table(metadata_df)

            if visualize_subs:
                visualize_substitutions(ref, hyp)

        except Exception as e:
            # Top-level UI boundary: surface the problem to the user.
            st.error(f"Error processing file: {str(e)}")
    elif st.session_state.manual_mode:
        st.info(
            "Please enter the evaluation results CSV file to visualize the results."
        )
        render_manual_eval()
    else:
        st.info(
            "Please upload an evaluation results CSV file to visualize the results."
        )
        st.markdown(
            """
            ### Expected CSV Format
            The CSV should have the following columns:
            - id
            - reference_text
            - predicted_text
            - norm_reference_text
            - norm_predicted_text
            - wer
            - wil
            - substitutions
            - deletions
            - insertions
            - hits
            - metadata_uuid
            - model
            - dataset
            - dataset_split
            - engine
            """
        )
# Script entry point: render the page when the file is executed directly.
if __name__ == "__main__":
    main()