# fbmc-chronos2/scripts/test_inference_pipeline.py
# Author: Evgueni Poloukarov
# Commit 44b73f4: "feat: implement zero-shot inference pipeline for Day 3"
"""
Smoke test for the Chronos 2 zero-shot inference pipeline.

Tests:
1. Data loading and preparation
2. Chronos 2 model loading
3. Inference on a single border (7-day horizon)
4. Output validation
5. Performance metrics
6. Saving the test forecast to disk
"""
import sys
from pathlib import Path
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
from inference.data_fetcher import DataFetcher
from inference.chronos_pipeline import ChronosForecaster
from datetime import datetime, timedelta
import torch
import pandas as pd
# Forecast horizon exercised by the smoke test: 7 days of hourly steps.
_PREDICTION_HOURS = 168
# Production horizon (14 days of hourly steps) used to extrapolate runtime.
_PRODUCTION_HOURS = 336
# Runtime budget for a full 14-day forecast, in seconds (5 minutes).
_TIME_BUDGET_SEC = 300
# Repository root (this script lives in <root>/scripts/). Output is anchored
# here, the same way the `src` import path is built at the top of this file,
# so results land in the same place regardless of the working directory.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
_OUTPUT_PATH = _PROJECT_ROOT / "data" / "evaluation" / "smoke_test_forecast.parquet"


def _check(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    Unlike a bare ``assert`` statement, this check is not stripped when
    Python runs with ``-O``, so the smoke test cannot silently pass.
    """
    if not condition:
        raise AssertionError(message)


def _print_environment():
    """Print PyTorch/CUDA details so the log records the hardware used."""
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("Running on CPU (inference will be slower)")


def _validate_inputs(context_df, future_df):
    """Check required columns and non-emptiness; report NaN counts.

    Raises:
        AssertionError: if a required context column is missing or either
            frame is empty. NaNs only produce a warning.
    """
    for column in ("timestamp", "border", "target"):
        _check(column in context_df.columns, f"Missing {column} column")
    _check(len(context_df) > 0, "Empty context data")
    _check(len(future_df) > 0, "Empty future data")
    print("[+] Data validation passed!")

    context_nulls = context_df.isnull().sum().sum()
    future_nulls = future_df.isnull().sum().sum()
    print(f"Context NaN count: {context_nulls}")
    print(f"Future NaN count: {future_nulls}")
    if context_nulls > 0 or future_nulls > 0:
        print("[!] Warning: Data contains NaN values (will be handled by model)")


def _validate_forecasts(forecasts):
    """Check forecast structure and, when a ``mean`` column exists, its range.

    Raises:
        AssertionError: on empty output, missing timestamp, an all-NaN mean,
            negative values, or values >= 20000.
    """
    _check(len(forecasts) > 0, "Empty forecasts")
    _check(
        'timestamp' in forecasts.columns or forecasts.index.name == 'timestamp',
        "Missing timestamp",
    )

    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("Forecast statistics:")
        print(f" Mean: {mean_forecast.mean():.2f} MW")
        print(f" Min: {mean_forecast.min():.2f} MW")
        print(f" Max: {mean_forecast.max():.2f} MW")
        print(f" Std: {mean_forecast.std():.2f} MW")
        # An all-NaN column makes min()/max() NaN, which would otherwise
        # surface as a misleading "Negative forecasts detected" failure.
        _check(mean_forecast.notna().any(), "Forecast mean is entirely NaN")
        # Sanity check: values should be reasonable for power capacity
        _check(mean_forecast.min() >= 0, "Negative forecasts detected")
        _check(mean_forecast.max() < 20000, "Unreasonably high forecasts")
    else:
        # Make it visible that only structural checks were applied.
        print("[!] Warning: no 'mean' column in forecasts; value range checks skipped")
    print("[+] Forecast validation passed!")


def _report_performance(metrics):
    """Print benchmark metrics and extrapolate runtime to the 14-day horizon.

    Returns:
        The measured ``inference_time_sec`` value from *metrics*.
    """
    print("Performance metrics:")
    for key, value in metrics.items():
        print(f" {key}: {value}")

    inference_time = metrics['inference_time_sec']
    # Rough estimate: assumes runtime scales linearly with horizon length.
    estimated_14d_time = inference_time * (_PRODUCTION_HOURS / _PREDICTION_HOURS)
    print(f"\nEstimated time for 14-day forecast: {estimated_14d_time:.1f}s ({estimated_14d_time/60:.1f} min)")
    if estimated_14d_time < _TIME_BUDGET_SEC:
        print("[+] Performance target met! (<5 min for 14 days)")
    else:
        print("[!] Warning: May not meet 5-minute target for 14 days")
    return inference_time


def main():
    """Run the end-to-end zero-shot inference smoke test.

    Loads local unified features, forecasts the first target border for
    7 days (168 hours) with ``amazon/chronos-2-large``, validates inputs and
    outputs, benchmarks inference time, and saves the forecast under
    ``<project root>/data/evaluation/smoke_test_forecast.parquet``.

    Raises:
        AssertionError: if any data or forecast check fails.
    """
    print("=" * 60)
    print("FBMC Chronos 2 Zero-Shot Inference - Smoke Test")
    print("=" * 60)

    # Step 1: Check environment
    print("\n[1] Checking environment...")
    _print_environment()

    # Step 2: Initialize DataFetcher
    print("\n[2] Initializing DataFetcher...")
    fetcher = DataFetcher(
        use_local=True,  # Use local files for testing
        context_length=512,  # Use 512 hours context
    )

    # Step 3: Load data
    print("\n[3] Loading unified features...")
    fetcher.load_data()
    min_date, max_date = fetcher.get_available_dates()
    print(f"Available data: {min_date} to {max_date}")
    # Select forecast date (use last month as test)
    forecast_date = max_date - timedelta(days=30)
    print(f"Test forecast date: {forecast_date}")

    # Step 4: Prepare inference data (single border, 7 days)
    print("\n[4] Preparing inference data (1 border, 7 days)...")
    _check(len(fetcher.target_borders) > 0, "No target borders available")
    test_border = fetcher.target_borders[0]  # Use first border
    print(f"Test border: {test_border}")
    context_df, future_df = fetcher.prepare_inference_data(
        forecast_date=forecast_date,
        prediction_length=_PREDICTION_HOURS,
        borders=[test_border],
    )
    print(f"Context shape: {context_df.shape}")
    print(f"Future shape: {future_df.shape}")

    # Step 5: Validate prepared data
    print("\n[5] Validating prepared data...")
    _validate_inputs(context_df, future_df)

    # Step 6: Initialize Chronos 2 forecaster
    print("\n[6] Initializing Chronos 2 forecaster...")
    forecaster = ChronosForecaster(
        model_name="amazon/chronos-2-large",
        device="auto",  # Will use GPU if available
    )

    # Step 7: Load model
    print("\n[7] Loading Chronos 2 Large model...")
    print("(This may take a few minutes on first load)")
    forecaster.load_model()
    print("[+] Model loaded successfully!")

    # Step 8: Run inference
    print("\n[8] Running zero-shot inference...")
    print(f"Forecasting {test_border} for 7 days (168 hours)")
    forecasts = forecaster.predict_single_border(
        border=test_border,
        context_df=context_df,
        future_df=future_df,
        prediction_length=_PREDICTION_HOURS,
        num_samples=100,  # 100 samples for probabilistic forecast
    )
    print(f"[+] Inference complete! Forecast shape: {forecasts.shape}")

    # Step 9: Validate forecasts
    print("\n[9] Validating forecasts...")
    _validate_forecasts(forecasts)

    # Step 10: Benchmark performance
    print("\n[10] Benchmarking inference performance...")
    metrics = forecaster.benchmark_inference(
        context_df=context_df,
        future_df=future_df,
        prediction_length=_PREDICTION_HOURS,
    )
    inference_time = _report_performance(metrics)

    # Step 11: Save test forecasts (directory created if it does not exist)
    print("\n[11] Saving test forecasts...")
    _OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    forecaster.save_forecasts(forecasts, str(_OUTPUT_PATH))
    print(f"[+] Saved to: {_OUTPUT_PATH}")

    # Summary
    print("\n" + "=" * 60)
    print("SMOKE TEST SUMMARY")
    print("=" * 60)
    print("[+] All tests passed!")
    print(f"[+] Border: {test_border}")
    print("[+] Forecast length: 168 hours (7 days)")
    print(f"[+] Inference time: {inference_time:.1f}s")
    print(f"[+] Output shape: {forecasts.shape}")
    print("\n[+] Ready for full inference run!")
    print("=" * 60)
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    main()