# Source code for commonpower.data_forecasting.nn_forecasting.dataset_wrappers

"""
Dataset wrappers for neural network forecasting.
"""
from __future__ import annotations

import torch
from torch.utils.data import Dataset

from commonpower.data_forecasting.base import DataSource
from commonpower.data_forecasting.nn_forecasting.data_splitting import DatasetSplit
from commonpower.data_forecasting.nn_forecasting.models import NNModule
from commonpower.data_forecasting.nn_forecasting.transform import IdentityTransform, Transformation


class DatasetWrapper(Dataset):
    """
    Base class determining how the train/val datasets are constructed.

    It acts as a wrapper around the data source and can be passed to a data loader as dataset.
    Subclasses implement ``_get_item`` to assemble a single (input, target) pair.
    """

    def __init__(
        self,
        data_source: DataSource,
        data_split: DatasetSplit,
        model: NNModule,
        targets: list[str],
        features: list[str],
        feature_transform: Transformation | None = None,
        target_transform: Transformation | None = None,
    ) -> None:
        """
        Store the wrapped objects and resolve the column indices of features and targets.

        Args:
            data_source (DataSource): The data source for the dataset.
            data_split (DatasetSplit): The split of the dataset.
            model (NNModule): The neural network model.
            targets (list[str]): The list of target variables.
            features (list[str]): The list of feature variables.
            feature_transform (Transformation, optional): The transformation to apply to the features.
                Defaults to a fresh IdentityTransform().
            target_transform (Transformation, optional): The transformation to apply to the targets.
                Defaults to a fresh IdentityTransform().
        """
        self.data_source = data_source
        self.model = model
        self.data_split = data_split

        # Defaults are created per instance: a call in the signature would be evaluated once at
        # definition time and the resulting object shared by every wrapper.
        self.feature_transform = feature_transform if feature_transform is not None else IdentityTransform()
        self.target_transform = target_transform if target_transform is not None else IdentityTransform()

        # Indices follow the variable order of the data source, not the order of the passed lists.
        # Names that the data source does not provide are silently ignored.
        variables = self.data_source.get_variables()
        self.feature_idxs = [i for i, var in enumerate(variables) if var in features]
        self.target_idxs = [i for i, var in enumerate(variables) if var in targets]
        # NOTE: __init__ must not return a value; returning ``self`` raises a TypeError on instantiation.

    def __len__(self) -> int:
        """
        Returns the number of elements in the dataset, i.e., the length of the data split.
        """
        return len(self.data_split)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the data point at the given index.

        The index is first mapped through ``data_split.adjust_index`` and then handed to ``_get_item``.

        Args:
            idx (int): Index of the data point.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (input, target)
        """
        return self._get_item(self.data_split.adjust_index(idx))

    def _get_item(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Builds the (input, target) pair for an already adjusted index. To be implemented by subclasses.

        Args:
            idx (int): Adjusted index of the data point.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class NStepAhead(DatasetWrapper):
    """
    Dataset wrapper yielding a look-back window of features and the following n steps of targets.
    """

    def __init__(
        self,
        data_source: DataSource,
        data_split: DatasetSplit,
        model: NNModule,
        targets: list[str],
        features: list[str],
        feature_transform: Transformation | None = None,
        target_transform: Transformation | None = None,
    ) -> None:
        """
        The NStepAhead wrapper inspects the passed model instance and determines the look back
        and steps ahead values from the model input and output shapes.
        The dataset is then constructed such that each data point has
        input dimension (n_look_back, n_features) and target dimension (n_steps_ahead, n_targets).

        Args:
            data_source (DataSource): The data source for the dataset.
            data_split (DatasetSplit): The split of the dataset.
            model (NNModule): The neural network model.
            targets (list[str]): The list of target variables.
            features (list[str]): The list of feature variables.
            feature_transform (Transformation, optional): The transformation to apply to the features.
                Defaults to a fresh IdentityTransform().
            target_transform (Transformation, optional): The transformation to apply to the targets.
                Defaults to a fresh IdentityTransform().
        """
        # Resolve the defaults here so the base class always receives concrete transformations
        # and no default instance is shared between wrappers.
        if feature_transform is None:
            feature_transform = IdentityTransform()
        if target_transform is None:
            target_transform = IdentityTransform()

        super().__init__(data_source, data_split, model, targets, features, feature_transform, target_transform)

        self.steps_ahead = model.output_shape[0]
        self.look_back = model.input_shape[0] - 1  # we are also considering the current time step in the model input

    def _get_item(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the data point at the given index.

        The input window spans from ``look_back`` steps before the indexed time step up to that
        time step; the target window spans from one step after it up to ``steps_ahead`` steps after it.

        Args:
            idx (int): Index of the data point.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (input, target)
        """
        freq = self.data_source.frequency

        # Time stamp of the indexed step, counted from the start of the data source's date range.
        idx_time = self.data_source.get_date_range()[0] + idx * freq

        feature_start_time = idx_time - self.look_back * freq
        target_start_time = idx_time + freq
        target_end_time = idx_time + self.steps_ahead * freq

        # NOTE(review): presumably the data source returns both interval end points — confirm.
        input_data = self.data_source(feature_start_time, idx_time)[:, self.feature_idxs]
        target_data = self.data_source(target_start_time, target_end_time)[:, self.target_idxs]

        input_data_transformed = self.feature_transform(input_data)
        target_data_transformed = self.target_transform(target_data)

        return torch.tensor(input_data_transformed).float(), torch.tensor(target_data_transformed).float()