Spaces:
Sleeping
Sleeping
| """ | |
| Smoke test for zero-shot inference pipeline | |
| Tests: | |
| 1. Data loading and preparation | |
| 2. Chronos 2 model loading | |
| 3. Inference on single border (7 days) | |
| 4. Output validation | |
| 5. Performance metrics | |
| """ | |
| import sys | |
| from pathlib import Path | |
| # Add src to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent / 'src')) | |
| from inference.data_fetcher import DataFetcher | |
| from inference.chronos_pipeline import ChronosForecaster | |
| from datetime import datetime, timedelta | |
| import torch | |
| import pandas as pd | |
# Forecast horizon exercised by the smoke test: 7 days of hourly steps.
PREDICTION_LENGTH = 168
# Horizon and wall-clock budget that the performance target refers to.
TARGET_HORIZON = 336  # 14 days of hourly steps
TARGET_TIME_SEC = 300  # 5 minutes
# Upper sanity bound (MW) for the mean forecast.
MAX_REASONABLE_MW = 20000
# Relative to the current working directory, as in the original script.
OUTPUT_PATH = "data/evaluation/smoke_test_forecast.parquet"


def _check(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    A bare ``assert`` statement is removed when Python runs with ``-O``.
    That would let this smoke test report success without checking anything,
    so the check is done explicitly instead.
    """
    if not condition:
        raise AssertionError(message)


def _report_environment():
    """Print the PyTorch version and the CUDA device (name and VRAM), if any."""
    print(f"PyTorch version: {torch.__version__}")
    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")
    if has_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("Running on CPU (inference will be slower)")


def main():
    """Run the end-to-end zero-shot inference smoke test.

    Steps:
    - Load local features via ``DataFetcher``.
    - Prepare a 7-day (168-hour) forecast window for the first target border.
    - Run ``ChronosForecaster`` on that window and validate the output.
    - Benchmark inference and save the forecast to ``OUTPUT_PATH``.

    Raises:
        AssertionError: if any data or forecast check fails.
    """
    print("=" * 60)
    print("FBMC Chronos 2 Zero-Shot Inference - Smoke Test")
    print("=" * 60)

    # Step 1: Check environment
    print("\n[1] Checking environment...")
    _report_environment()

    # Step 2: Initialize DataFetcher
    print("\n[2] Initializing DataFetcher...")
    fetcher = DataFetcher(
        use_local=True,  # Use local files for testing
        context_length=512,  # Use 512 hours context
    )

    # Step 3: Load data
    print("\n[3] Loading unified features...")
    fetcher.load_data()
    min_date, max_date = fetcher.get_available_dates()
    print(f"Available data: {min_date} to {max_date}")
    # Forecast from 30 days before the end of the data, so there is
    # history to draw on for the test.
    forecast_date = max_date - timedelta(days=30)
    print(f"Test forecast date: {forecast_date}")

    # Step 4: Prepare inference data (single border, 7 days)
    print("\n[4] Preparing inference data (1 border, 7 days)...")
    test_border = fetcher.target_borders[0]  # Use first border
    print(f"Test border: {test_border}")
    context_df, future_df = fetcher.prepare_inference_data(
        forecast_date=forecast_date,
        prediction_length=PREDICTION_LENGTH,
        borders=[test_border],
    )
    print(f"Context shape: {context_df.shape}")
    print(f"Future shape: {future_df.shape}")

    # Step 5: Validate prepared data
    print("\n[5] Validating prepared data...")
    for column in ('timestamp', 'border', 'target'):
        _check(column in context_df.columns, f"Missing {column} column")
    _check(len(context_df) > 0, "Empty context data")
    _check(len(future_df) > 0, "Empty future data")
    print("[+] Data validation passed!")

    context_nulls = context_df.isnull().sum().sum()
    future_nulls = future_df.isnull().sum().sum()
    print(f"Context NaN count: {context_nulls}")
    print(f"Future NaN count: {future_nulls}")
    if context_nulls > 0 or future_nulls > 0:
        print("[!] Warning: Data contains NaN values (will be handled by model)")

    # Step 6: Initialize Chronos 2 forecaster
    print("\n[6] Initializing Chronos 2 forecaster...")
    forecaster = ChronosForecaster(
        model_name="amazon/chronos-2-large",
        device="auto",  # Will use GPU if available
    )

    # Step 7: Load model
    print("\n[7] Loading Chronos 2 Large model...")
    print("(This may take a few minutes on first load)")
    forecaster.load_model()
    print("[+] Model loaded successfully!")

    # Step 8: Run inference
    print("\n[8] Running zero-shot inference...")
    print(f"Forecasting {test_border} for 7 days (168 hours)")
    forecasts = forecaster.predict_single_border(
        border=test_border,
        context_df=context_df,
        future_df=future_df,
        prediction_length=PREDICTION_LENGTH,
        num_samples=100,  # 100 samples for probabilistic forecast
    )
    print(f"[+] Inference complete! Forecast shape: {forecasts.shape}")

    # Step 9: Validate forecasts
    print("\n[9] Validating forecasts...")
    _check(len(forecasts) > 0, "Empty forecasts")
    _check(
        'timestamp' in forecasts.columns or forecasts.index.name == 'timestamp',
        "Missing timestamp",
    )
    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("Forecast statistics:")
        print(f"  Mean: {mean_forecast.mean():.2f} MW")
        print(f"  Min: {mean_forecast.min():.2f} MW")
        print(f"  Max: {mean_forecast.max():.2f} MW")
        print(f"  Std: {mean_forecast.std():.2f} MW")
        # min()/max() skip NaN, so check for NaN explicitly before the
        # range checks.
        _check(mean_forecast.notna().all(), "NaN values in mean forecast")
        # Sanity check: values should be reasonable for power capacity
        _check(mean_forecast.min() >= 0, "Negative forecasts detected")
        _check(mean_forecast.max() < MAX_REASONABLE_MW, "Unreasonably high forecasts")
    else:
        # Do not claim the value checks ran when they could not.
        print("[!] Warning: no 'mean' column in forecasts; value range checks skipped")
    print("[+] Forecast validation passed!")

    # Step 10: Benchmark performance
    print("\n[10] Benchmarking inference performance...")
    metrics = forecaster.benchmark_inference(
        context_df=context_df,
        future_df=future_df,
        prediction_length=PREDICTION_LENGTH,
    )
    print("Performance metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value}")

    # Linear extrapolation to the 14-day horizon. This assumes inference time
    # scales roughly with the horizon length.
    estimated_14d_time = metrics['inference_time_sec'] * (TARGET_HORIZON / PREDICTION_LENGTH)
    print(f"\nEstimated time for 14-day forecast: {estimated_14d_time:.1f}s ({estimated_14d_time/60:.1f} min)")
    if estimated_14d_time < TARGET_TIME_SEC:
        print("[+] Performance target met! (<5 min for 14 days)")
    else:
        print("[!] Warning: May not meet 5-minute target for 14 days")

    # Step 11: Save test forecasts
    print("\n[11] Saving test forecasts...")
    # Make sure the target directory exists. This is harmless if save_forecasts
    # already creates it.
    Path(OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)
    forecaster.save_forecasts(forecasts, OUTPUT_PATH)
    print(f"[+] Saved to: {OUTPUT_PATH}")

    # Summary
    print("\n" + "=" * 60)
    print("SMOKE TEST SUMMARY")
    print("=" * 60)
    print("[+] All tests passed!")
    print(f"[+] Border: {test_border}")
    print("[+] Forecast length: 168 hours (7 days)")
    print(f"[+] Inference time: {metrics['inference_time_sec']:.1f}s")
    print(f"[+] Output shape: {forecasts.shape}")
    print("\n[+] Ready for full inference run!")
    print("=" * 60)
# Execute the smoke test only when this file is run directly, never on import.
if __name__ == '__main__':
    main()