# fbmc-chronos2/src/inference/data_fetcher.py
"""
Data Fetcher for Zero-Shot Inference
Prepares data for Chronos 2 inference by:
1. Loading unified features from HuggingFace Dataset
2. Identifying future covariates from metadata
3. Preparing context window (historical data)
4. Preparing future covariates for forecast horizon
5. Formatting data for Chronos 2 predict_df() API
"""
from pathlib import Path
from typing import Tuple, List, Optional
import pandas as pd
import polars as pl
from datetime import datetime, timedelta
from datasets import load_dataset
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DataFetcher:
"""
Fetches and prepares data for zero-shot Chronos 2 inference.
Handles:
- Loading unified features (2,553 features)
- Identifying future covariates (615 features)
- Creating context windows for each border
- Extending future covariates into forecast horizon
"""
def __init__(
self,
dataset_name: str = "evgueni-p/fbmc-features-24month",
local_features_path: Optional[str] = None,
local_metadata_path: Optional[str] = None,
context_length: int = 512,
use_local: bool = False
):
"""
Initialize DataFetcher.
Args:
dataset_name: HuggingFace dataset name
local_features_path: Path to local features parquet file
local_metadata_path: Path to local metadata CSV
context_length: Number of hours to use as context (default: 512)
use_local: If True, load from local files instead of HF Dataset
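
        Example (illustrative sketch, not part of the pipeline; assumes the
        default local parquet/CSV paths exist on disk):

            fetcher = DataFetcher(use_local=True, context_length=512)
            fetcher.load_data()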
"""
self.dataset_name = dataset_name
self.local_features_path = local_features_path or "data/processed/features_unified_24month.parquet"
self.local_metadata_path = local_metadata_path or "data/processed/features_unified_metadata.csv"
self.context_length = context_length
self.use_local = use_local
# Will be loaded lazily
self.features_df: Optional[pl.DataFrame] = None
self.metadata_df: Optional[pd.DataFrame] = None
self.future_covariate_cols: Optional[List[str]] = None
self.target_borders: Optional[List[str]] = None
def load_data(self):
"""Load unified features and metadata."""
logger.info("Loading unified features and metadata...")
if self.use_local:
# Load from local files
logger.info(f"Loading features from: {self.local_features_path}")
self.features_df = pl.read_parquet(self.local_features_path)
logger.info(f"Loading metadata from: {self.local_metadata_path}")
self.metadata_df = pd.read_csv(self.local_metadata_path)
else:
# Load from HuggingFace Dataset
logger.info(f"Loading features from HF Dataset: {self.dataset_name}")
dataset = load_dataset(self.dataset_name, split="train")
self.features_df = pl.from_pandas(dataset.to_pandas())
# Try to load metadata from HF Dataset
try:
metadata_dataset = load_dataset(self.dataset_name, data_files="metadata.csv", split="train")
self.metadata_df = metadata_dataset.to_pandas()
            except Exception as exc:
                logger.warning(f"Could not load metadata from HF Dataset ({exc}), falling back to local")
                self.metadata_df = pd.read_csv(self.local_metadata_path)
        # Ensure timestamp column is datetime (parquet normally preserves the dtype,
        # but the HF round-trip through pandas can leave it as a string)
        if 'timestamp' in self.features_df.columns and self.features_df.schema['timestamp'] == pl.Utf8:
            self.features_df = self.features_df.with_columns(
                pl.col('timestamp').str.to_datetime()
            )
logger.info(f"Loaded {len(self.features_df)} rows, {len(self.features_df.columns)} columns")
logger.info(f"Date range: {self.features_df['timestamp'].min()} to {self.features_df['timestamp'].max()}")
# Identify future covariates
self._identify_future_covariates()
# Identify target borders
self._identify_target_borders()
def _identify_future_covariates(self):
"""Identify columns that are future covariates from metadata."""
logger.info("Identifying future covariates from metadata...")
# Filter for future covariates
future_cov_meta = self.metadata_df[
self.metadata_df['is_future_covariate'] == True
]
self.future_covariate_cols = future_cov_meta['feature_name'].tolist()
logger.info(f"Found {len(self.future_covariate_cols)} future covariates")
logger.info(f"Categories: {future_cov_meta['category'].value_counts().to_dict()}")
def _identify_target_borders(self):
"""Identify target borders from NTC columns."""
logger.info("Identifying target borders...")
# Find all ntc_actual_* columns
ntc_cols = [col for col in self.features_df.columns if col.startswith('ntc_actual_')]
# Extract border names
self.target_borders = [col.replace('ntc_actual_', '') for col in ntc_cols]
logger.info(f"Found {len(self.target_borders)} target borders")
logger.info(f"Borders: {', '.join(self.target_borders[:5])}...")
def prepare_inference_data(
self,
forecast_date: datetime,
prediction_length: int = 336, # 14 days
borders: Optional[List[str]] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Prepare context and future data for Chronos 2 inference.
Args:
forecast_date: The date to forecast from (as-of date)
prediction_length: Number of hours to forecast (default: 336 = 14 days)
borders: List of borders to forecast (default: all borders)
Returns:
context_df: Historical data (timestamp, border, target, all features)
future_df: Future covariates (timestamp, border, future covariates only)
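
        Example (illustrative; the forecast date below is hypothetical and must
        fall inside the loaded dataset's date range):

            context_df, future_df = fetcher.prepare_inference_data(
                forecast_date=datetime(2024, 6, 1),
                prediction_length=336,
            )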
"""
if self.features_df is None:
self.load_data()
borders = borders or self.target_borders
logger.info(f"Preparing inference data for {len(borders)} borders")
logger.info(f"Forecast date: {forecast_date}")
logger.info(f"Context length: {self.context_length} hours")
logger.info(f"Prediction length: {prediction_length} hours")
# Extract context window (historical data)
context_start = forecast_date - timedelta(hours=self.context_length)
context_df = self.features_df.filter(
(pl.col('timestamp') >= context_start) &
(pl.col('timestamp') < forecast_date)
)
logger.info(f"Context window: {context_df['timestamp'].min()} to {context_df['timestamp'].max()}")
logger.info(f"Context rows: {len(context_df)}")
        # Prepare context data for each border.
        # All ntc_actual_* columns are dropped from the shared feature set so that
        # every border frame has an identical schema (each border's own target is
        # carried separately in 'target'); pl.concat requires matching schemas.
        ntc_cols = {f'ntc_actual_{b}' for b in self.target_borders}
        feature_cols = [col for col in context_df.columns if col != 'timestamp' and col not in ntc_cols]
        context_dfs = []
        for border in borders:
            ntc_col = f'ntc_actual_{border}'
            if ntc_col not in context_df.columns:
                logger.warning(f"Border {border} not found in features, skipping")
                continue
            # Select: timestamp, border id, target, shared features
            border_context = context_df.select([
                'timestamp',
                pl.lit(border).alias('border'),
                pl.col(ntc_col).alias('target'),
                *feature_cols
            ])
            context_dfs.append(border_context)
        # Combine all borders
        context_combined = pl.concat(context_dfs)
logger.info(f"Combined context shape: {context_combined.shape}")
# Prepare future covariates
# For MVP: Use last known values or simple forward-fill
# TODO: In production, fetch fresh weather forecasts, generate temporal features
logger.info("Preparing future covariates...")
        future_dfs = []
        # Future timestamps and the last known covariate values are the same for
        # every border, so compute them once outside the loop
        future_timestamps = pd.date_range(
            start=forecast_date,
            periods=prediction_length,
            freq='h'
        )
        # Last known row of the context window
        last_row = context_df.filter(pl.col('timestamp') == context_df['timestamp'].max())
        # Keep only future covariates that actually exist in the feature table
        available_future_covs = [c for c in self.future_covariate_cols if c in last_row.columns]
        future_values = last_row.select(available_future_covs)
        for border in borders:
            # Repeat the last known values across the forecast horizon (forward-fill)
            future_border_df = pl.DataFrame({
                'timestamp': future_timestamps,
                'border': [border] * len(future_timestamps)
            }).with_columns([
                pl.lit(future_values[col][0]).alias(col)
                for col in available_future_covs
            ])
            future_dfs.append(future_border_df)
# Combine all borders
future_combined = pl.concat(future_dfs)
logger.info(f"Future covariates shape: {future_combined.shape}")
# Convert to pandas for Chronos 2
context_pd = context_combined.to_pandas()
future_pd = future_combined.to_pandas()
logger.info("Data preparation complete!")
logger.info(f"Context: {context_pd.shape}, Future: {future_pd.shape}")
return context_pd, future_pd
def get_available_dates(self) -> Tuple[datetime, datetime]:
"""Get the available date range in the dataset."""
if self.features_df is None:
self.load_data()
min_date = self.features_df['timestamp'].min()
max_date = self.features_df['timestamp'].max()
return min_date, max_date
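

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the pipeline): load the
    # local feature files, pick a forecast date safely inside the available range,
    # and print the shapes of the frames that would be passed to Chronos 2.
    # The 15-day offset from the end of the data is an arbitrary example choice.
    fetcher = DataFetcher(use_local=True)
    min_date, max_date = fetcher.get_available_dates()
    forecast_date = max_date - timedelta(days=15)
    context_df, future_df = fetcher.prepare_inference_data(forecast_date=forecast_date)
    print(f"Context: {context_df.shape}, Future: {future_df.shape}")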