[ENH] de-novo implementation of `LTSFTransformer` based on `cure-lab` research code base by geetu040 · Pull Request #6202 · sktime/sktime · GitHub

[ENH] de-novo implementation of LTSFTransformer based on cure-lab research code base #6202

Merged
42 commits merged into sktime:main on Aug 6, 2024

Conversation

geetu040
Contributor
@geetu040 geetu040 commented Mar 24, 2024

Reference Issues/PRs

Implements LTSFTransformer from #4939

What does this implement/fix? Explain your changes.

New forecaster LTSFTransformer

Does your contribution introduce a new dependency? If yes, which one?

No

What should a reviewer concentrate their feedback on?

  • Class names and API layout

Did you add any tests for the change?

Not yet

PR checklist

For all contributions
  • Optionally, for added estimators: I've added myself and possibly others to the maintainers tag - do this if you want to become the owner or maintainer of an estimator you added.
    See here for further details on the algorithm maintainer role.
  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.
For new estimators
  • I've added the estimator to the API reference - in docs/source/api_reference/taskname.rst, follow the pattern.
  • I've added one or more illustrative usage examples to the docstring, in a pydocstyle compliant Examples section.

@fkiraly
Collaborator
fkiraly commented Mar 24, 2024

Nice! Quick question, is this lifting the code from somewhere (in which case we need to give credit in docstrings etc.), or is it a de-novo implementation?

@fkiraly fkiraly added the module:forecasting and enhancement labels Mar 24, 2024
@geetu040
Contributor Author

Nice! Quick question, is this lifting the code from somewhere (in which case we need to give credit in docstrings etc.), or is it a de-novo implementation?

Originally copied, then edited for sktime compatibility.
I'll work on the docstrings and credit the original author as soon as the blocks are connected and the interface is ready.

@geetu040
Contributor Author

I need to implement the forward pass so that it aligns with the sktime interface and the transformer architecture. @fkiraly @benHeid your input here would be really appreciated.

        def forward(self, x):
            """Forward pass for LSTF-Transformer Network.

            Parameters
            ----------
            x : torch.Tensor
                torch.Tensor of shape [Batch, Input Sequence Length, Channel]

            Returns
            -------
            x : torch.Tensor
                output of Linear Model. x.shape = [Batch, Output Length, Channel]
            """
            from torch import ones

            batch_size = x.size(0)
            seq_len = self.seq_len
            pred_len = self.pred_len
            num_features = x.size(2)
            num_X_features = 5

            x_enc = x
            x_mark_enc = ones(batch_size, seq_len, num_X_features)
            x_dec = ones(batch_size, pred_len, num_features)
            x_mark_dec = ones(batch_size, pred_len, num_X_features)

            return self._forward(x_enc, x_mark_enc, x_dec, x_mark_dec)

        def _forward(
            self,
            x_enc,
            x_mark_enc,
            x_dec,
            x_mark_dec,
            enc_self_mask=None,
            dec_self_mask=None,
            dec_enc_mask=None,
        ):
            enc_out = self.enc_embedding(x_enc, x_mark_enc)
            enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

            dec_out = self.dec_embedding(x_dec, x_mark_dec)
            dec_out = self.decoder(
                dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask
            )

            if self.output_attention:
                return dec_out[:, -self.pred_len :, :], attns
            else:
                return dec_out[:, -self.pred_len :, :]  # [B, L, D]

The code above shows the `_forward` method used by the transformer; in `forward` I simplified the inputs to just get the pipeline running and fix other architecture issues first.
These are the important parameters required in the transformer's forward pass:

  1. x_enc: Input data for the encoder. Can be the historical y in the sktime context.
  2. x_mark_enc: Time embeddings (or positional embeddings) for the encoder input data. Can be the historical X in the sktime context. What happens when there is no X?
  3. x_dec: Input data for the decoder. Can be y_pred in the sktime context, where it would be fh when predicting and y_pred while training - this would require changing the pytorch adapter.
  4. x_mark_dec: Time embeddings (or positional embeddings) for the decoder input data. Can be X_pred, but again, what happens when there is no exogenous data?

@geetu040
Contributor Author

Summarizing above ...

Problem Statement

The sktime interface provides X and y for training and prediction. The transformer architecture, consisting of an encoder and a decoder, takes 4 inputs:

  • x_enc: input sequence for the encoder
  • x_mark_enc: time embeddings (or positional embeddings) of input sequence for encoder
  • x_dec: target sequence for the decoder
  • x_mark_dec: time embeddings (or positional embeddings) of target sequence for decoder

Now we have to make changes in the pytorch adapter or in the forward pass to translate X and y into these 4 inputs understood by the transformer.

Proposed Design

  1. We split the X and y seen in training into segments X_enc, X_dec and y_enc, y_dec respectively, by a specific ratio
  2. then use
    • y_enc as x_enc - input sequence for the encoder
    • y_dec as x_dec - target sequence for the decoder
    • X_enc as x_mark_enc - time embeddings (or positional embeddings) of the input sequence for the encoder
    • X_dec as x_mark_dec - time embeddings (or positional embeddings) of the target sequence for the decoder

These changes would need to be made in BaseDeepNetworkPyTorch; a rough sketch of the mapping is below.
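
A minimal sketch of the proposed mapping, assuming numpy arrays of shape (n_timepoints, n_channels); the function name and the split_ratio default are illustrative, not part of the adapter:

    import numpy as np

    def split_enc_dec(y: np.ndarray, X: np.ndarray, split_ratio: float = 0.7):
        """Split y and X along time and map the segments onto the 4 transformer inputs."""
        cut = int(len(y) * split_ratio)
        y_enc, y_dec = y[:cut], y[cut:]  # target: encoder / decoder segments
        X_enc, X_dec = X[:cut], X[cut:]  # exogenous: encoder / decoder segments
        return {
            "x_enc": y_enc,       # input sequence for the encoder
            "x_dec": y_dec,       # target sequence for the decoder
            "x_mark_enc": X_enc,  # marks (time embeddings) for the encoder input
            "x_mark_dec": X_dec,  # marks (time embeddings) for the decoder target
        }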

Requested Feedback

@benHeid
Q1: Please review the above design. I would need your acknowledgment before implementing it in code.
Q2: How would this strategy be used for prediction? There we have X_pred and y_pred, which can be fed as encoder inputs, but how do we choose the decoder inputs for the transformer?

@benHeid
Contributor
benHeid commented Jun 12, 2024

I would propose taking a look at their experiment file exp/exp_main.py and at their dataset implementation.
I would try to follow their approach to creating the dataset as closely as possible. The difference would be that we wouldn't read from a file; instead we would provide the data directly.

The separation into y and x is done in the dataset classes. Afterwards, the masking for prediction etc. is done in the exp_main file.

So regarding Q1: I would propose using exactly their approach, but I assume yours is quite similar. Regarding Q2: y would be a concatenation of the historical values and zeros, if I understood their implementation correctly.
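
If I read that correctly, the decoder input would be built roughly like this (a sketch with illustrative names, not their exact code; label_len and pred_len follow the cure-lab naming):

    import torch

    def decoder_input(y_hist: torch.Tensor, label_len: int, pred_len: int) -> torch.Tensor:
        """Concatenate the last `label_len` observed values with `pred_len` zeros.

        y_hist has shape (seq_len, n_channels); the result has shape
        (label_len + pred_len, n_channels).
        """
        zeros = torch.zeros(pred_len, y_hist.size(-1), dtype=y_hist.dtype)
        return torch.cat([y_hist[-label_len:], zeros], dim=0)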

Please ping me if you would like clarification at some point.

@geetu040
Contributor Author

@benHeid I think we cannot adopt the PredDataset class from cure-lab, as it uses the available test data for validation. In our case, we have real-time data instead of already available prediction data.
Please let me know if I am right about this. The proposed method would be to create a Dataset class of our own that uses, for the encoder, X and y from some of the historical values seen during fit, and, for the decoder, the X provided in predict and y as zeros.
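
Roughly along these lines (a sketch with illustrative names, assuming 2-D torch tensors of shape (n_timepoints, n_channels) and a single prediction window; not the final class):

    import torch
    from torch.utils.data import Dataset

    class _PredDatasetSketch(Dataset):
        """Sketch: the encoder reads the tail of the history, the decoder gets
        `label_len` known values followed by zeros for the horizon."""

        def __init__(self, y_hist, marks_hist, marks_future, seq_len, label_len, pred_len):
            # y_hist: historical target, marks_hist/marks_future: time features
            self.y_hist = torch.as_tensor(y_hist, dtype=torch.float32)
            self.marks_hist = torch.as_tensor(marks_hist, dtype=torch.float32)
            self.marks_future = torch.as_tensor(marks_future, dtype=torch.float32)
            self.seq_len, self.label_len, self.pred_len = seq_len, label_len, pred_len

        def __len__(self):
            return 1  # a single window: the latest `seq_len` observations

        def __getitem__(self, idx):
            x_enc = self.y_hist[-self.seq_len:]
            x_mark_enc = self.marks_hist[-self.seq_len:]
            zeros = torch.zeros(self.pred_len, x_enc.size(-1))
            x_dec = torch.cat([x_enc[-self.label_len:], zeros], dim=0)
            x_mark_dec = torch.cat([x_mark_enc[-self.label_len:], self.marks_future], dim=0)
            return {
                "x_enc": x_enc,
                "x_mark_enc": x_mark_enc,
                "x_dec": x_dec,
                "x_mark_dec": x_mark_dec,
            }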

@geetu040
Contributor Author

@fkiraly this is how the docstring looks at the moment
see in file


    Parameters
    ----------
    seq_len : int
        Length of the input sequence.
    pred_len : int
        Length of the prediction sequence.
    label_len : int, optional (default=2)
        Length of the label sequence.
    num_epochs : int, optional (default=16)
        Number of epochs for training.
    batch_size : int, optional (default=8)
        Size of the batch.
    in_channels : int, optional (default=1)
        Number of input channels.
    individual : bool, optional (default=False)
        Whether to use individual models for each series.
    criterion : str or callable, optional
        Loss function to use.
    criterion_kwargs : dict, optional
        Additional keyword arguments for the loss function.
    optimizer : str or callable, optional
        Optimizer to use.
    optimizer_kwargs : dict, optional
        Additional keyword arguments for the optimizer.
    lr : float, optional (default=0.001)
        Learning rate.
    custom_dataset_train : torch.utils.data.Dataset, optional
        Custom dataset for training.
    custom_dataset_pred : torch.utils.data.Dataset, optional
        Custom dataset for prediction.
    output_attention : bool, optional (default=False)
        Whether to output attention weights.
    embed_type : int, optional (default=0)
        Type of embedding to use.
    embed : str, optional (default="fixed")
        Type of embedding.
    enc_in : int, optional (default=7)
        Number of encoder input features.
    dec_in : int, optional (default=7)
        Number of decoder input features.
    d_model : int, optional (default=512)
        Dimension of the model.
    n_heads : int, optional (default=8)
        Number of attention heads.
    d_ff : int, optional (default=2048)
        Dimension of the feed-forward network.
    e_layers : int, optional (default=3)
        Number of encoder layers.
    d_layers : int, optional (default=2)
        Number of decoder layers.
    factor : int, optional (default=5)
        Factor for attention.
    dropout : float, optional (default=0.1)
        Dropout rate.
    activation : str, optional (default="relu")
        Activation function.
    c_out : int, optional (default=7)
        Number of output features.
    freq : str, optional (default="h")
        Frequency of the data.

    Examples
    --------
    >>> from sktime.forecasting.ltsf import LTSFTransformer
    >>> from sktime.datasets import load_longley
    >>>
    >>> batch_size = 5
    >>> seq_len = 5
    >>> label_len = 2
    >>> pred_len = 3
    >>> num_features = 1
    >>>
    >>> y, X = load_longley()
    >>> split_point = len(y) - pred_len
    >>> X_train, X_test = X[:split_point], X[split_point:]
    >>> y_train, y_test = y[:split_point], y[split_point:]
    >>>
    >>> model = LTSFTransformer(
    ...     seq_len=seq_len,
    ...     pred_len=pred_len,
    ...     label_len=label_len,
    ...     output_attention=False,
    ...     embed_type=0,
    ...     embed="fixed",
    ...     enc_in=num_features,
    ...     dec_in=num_features,
    ...     d_model=512,
    ...     n_heads=8,
    ...     d_ff=2048,
    ...     e_layers=1,
    ...     d_layers=1,
    ...     factor=5,
    ...     dropout=0.1,
    ...     activation="relu",
    ...     c_out=pred_len,
    ...     freq="h",
    ...     num_epochs=1,
    ...     batch_size=batch_size,
    ... )
    >>>
    >>> model.fit(y_train, X_train, fh=[1, 2, 3])
    >>> pred = model.predict(X=X_test)

@geetu040
Contributor Author

@benHeid I have tried to train and predict using the current implementation (no changes to the architecture) and the predictions don't seem to move at all. Should I inquire more into this by trying other loss and optimizer methods, or keep working on the code and check this later?

@geetu040
Contributor Author

I believe the dataset class is complete and ready for review. I have made some changes to the original dataset class provided by cure-lab. These are only syntax changes, with no change to the core logic, although I have removed some code that was not useful for this interface.


Changes made

  1. seq_len, label_len, pred_len are taken as separate parameters rather than a list of sizes, as this makes a better interface
  2. I keep the complete data instead of splitting it into train-test-val, to make it compatible with the existing pytorch adapter
  3. removed the scaling option, as it should be done at the adapter level
  4. removed the code that concatenates X features with y, as it goes against the existing pytorch adapter implementation; if this step is needed, it should be adopted at the adapter level

Todos

  1. scale data at adapter level
  2. the current model will not work if the provided data index is not of a valid format, i.e. DatetimeIndex; this needs to be looked into
  3. there is a lot of code that seems hard-coded and highly reliant on pandas frequency; that needs to be changed before it breaks
  4. predictions are still not improving with the epochs; that also needs a deeper study of the implementation, or of the hyperparameter configuration

By the way, this is how the dataloader output looks:

 ========== Data ========== 
[[102.]
 [108.]
 [122.]
 [119.]
 [111.]
 [125.]
 [138.]
 [138.]
 [126.]
 [109.]
 [ 94.]
 [108.]
 [105.]
 [116.]
 [131.]
 [125.]
 [115.]
 [139.]
 [160.]
 [160.]]
 ========== Dataloader: x ========== 
{'x_enc': tensor([[102.],
        [108.],
        [122.],
        [119.],
        [111.]]), 'x_mark_enc': tensor([[ 1., 31.,  0.,  0.],
        [ 2., 28.,  0.,  0.],
        [ 3., 31.,  3.,  0.],
        [ 4., 30.,  5.,  0.],
        [ 5., 31.,  1.,  0.]]), 'x_dec': tensor([[119.],
        [111.],
        [  0.],
        [  0.],
        [  0.]]), 'x_mark_dec': tensor([[ 4., 30.,  5.,  0.],
        [ 5., 31.,  1.,  0.],
        [ 6., 30.,  3.,  0.],
        [ 7., 31.,  6.,  0.],
        [ 8., 31.,  2.,  0.]])}
 ========== Dataloader: y ========== 
tensor([[125.],
        [138.],
        [138.]])
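
For reference, the four mark columns above appear to be month, day-of-month, weekday and hour; here is a sketch of how such marks could be derived from a pandas DatetimeIndex (illustrative, not necessarily the exact encoding used in the dataset class):

    import numpy as np
    import pandas as pd

    def calendar_marks(index: pd.DatetimeIndex) -> np.ndarray:
        """Stack month, day, weekday and hour into a (len(index), 4) float array."""
        return np.stack(
            [index.month, index.day, index.weekday, index.hour], axis=1
        ).astype("float32")

    # month-end timestamps give rows like [1., 31., <weekday>, 0.], as in the dump above
    marks = calendar_marks(pd.date_range("1949-01-31", periods=5, freq="M"))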

@geetu040
Contributor Author

The following changes were made to reformat the init params:

  1. label_len renamed to context_len, as this name gives more information about the parameter
  2. output_attention removed from the parameters, as the user does not need to see the attention generated during the process and it would break the pytorch adapter
  3. enc_in, dec_in, c_out are removed from the parameters and a new param num_features is added temporarily, which will later be removed as well; num_features should be inferred from the data, and enc_in, dec_in, c_out should be equal to it
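
As a tiny illustration of the last point (an illustrative helper, not the adapter's final code):

    import pandas as pd

    def infer_channel_params(y) -> dict:
        """Infer enc_in, dec_in and c_out from the width of the training target y."""
        num_features = y.shape[-1] if getattr(y, "ndim", 1) > 1 else 1
        return {"enc_in": num_features, "dec_in": num_features, "c_out": num_features}

    # e.g. a univariate pd.Series gives 1, a 3-column pd.DataFrame gives 3
    infer_channel_params(pd.DataFrame({"a": [1], "b": [2], "c": [3]}))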

@benHeid
Contributor
benHeid commented Jul 1, 2024

Personally, I would just force-ignore temporal_encoding if the index is not temporal.

I agree

@geetu040 geetu040 marked this pull request as ready for review July 1, 2024 21:11
@fkiraly fkiraly changed the title from [ENH] Implements LTSFTransformer to [ENH] de-novo implementation of LTSFTransformer based on cure-lab research code base Jul 4, 2024
Collaborator
@fkiraly fkiraly left a comment


Can you please resolve merge conflicts and ensure tests run, @geetu040?

Collaborator
@fkiraly fkiraly left a comment


Looks good!

One question, I see the test test_predict_time_index_in_sample_full is skipped - why is that skipped?

@geetu040
Contributor Author

One question, I see the test test_predict_time_index_in_sample_full is skipped - why is that skipped?

Because LTSFTransformer forecasts the next values and does not work on fh with negative values, and this test case checks for negative values in fh, if I am right.

@fkiraly
Collaborator
fkiraly commented Jul 17, 2024

What you describe is that the forecaster cannot make insample forecasts - this should be addressed by setting the insample tag correctly, and then the test should work:
https://www.sktime.net/en/latest/api_reference/auto_generated/sktime.registry._tags.capability__insample.html#sktime.registry._tags.capability__insample
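
For context, this tag is declared in the estimator's _tags dictionary; a minimal sketch (the class name is illustrative):

    from sktime.forecasting.base import BaseForecaster

    class _NoInsampleForecaster(BaseForecaster):
        """Sketch: a forecaster declaring that it cannot make in-sample forecasts."""

        # the test framework can then check in-sample behavior against this tag
        _tags = {"capability:insample": False}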

@geetu040
Contributor Author

"capability:insample": False has been set in the parent class BaseDeepNetworkPyTorch of LTSFTransformer
I skipped test_predict_time_index_in_sample_full for LTSFTransformer as it was skipped for other LTSF algorithms like LTSFNLinearForecaster, and I thought it was for in-sample predictions. If this test case is not checking for in-sample predictions, then what does it check?

@fkiraly
Collaborator
fkiraly commented Jul 18, 2024

it will check that the tag is correctly set, I believe

@geetu040
Contributor Author

it will check that the tag is correctly set, I believe

yes, it checks via the tag - updated the tests config

@geetu040 geetu040 requested a review from fkiraly July 24, 2024 11:27
fkiraly previously approved these changes Jul 24, 2024
Collaborator
@fkiraly fkiraly left a comment


Well, looks like we did not have to skip those tests after all!

Would appreciate a review from @benHeid before merging, due to familiarity with the algorithm - this is a complex PR.

@fkiraly fkiraly merged commit 6397e87 into sktime:main Aug 6, 2024
68 checks passed
Labels
enhancement, module:forecasting
Projects
Status: Done
Development


3 participants