[ENH] de-novo implementation of LTSFTransformer based on cure-lab research code base #6202
Conversation
Nice! Quick question: is this lifting the code from somewhere (in which case we need to give credit in docstrings etc.), or is it a de-novo implementation?
Originally copied, then edited for sktime compatibility.
I need to implement the `forward` method:

```python
def forward(self, x):
    """Forward pass for LTSF-Transformer Network.

    Parameters
    ----------
    x : torch.Tensor
        torch.Tensor of shape [Batch, Input Sequence Length, Channel]

    Returns
    -------
    x : torch.Tensor
        output of Linear Model. x.shape = [Batch, Output Length, Channel]
    """
    from torch import ones

    batch_size = x.size(0)
    seq_len = self.seq_len
    pred_len = self.pred_len
    num_features = x.size(2)
    num_X_features = 5  # placeholder count of time-feature (mark) columns

    # the network expects four inputs; for now, the decoder input and the
    # time-feature marks are filled with placeholder ones
    x_enc = x
    x_mark_enc = ones(batch_size, seq_len, num_X_features)
    x_dec = ones(batch_size, pred_len, num_features)
    x_mark_dec = ones(batch_size, pred_len, num_X_features)

    return self._forward(x_enc, x_mark_enc, x_dec, x_mark_dec)


def _forward(
    self,
    x_enc,
    x_mark_enc,
    x_dec,
    x_mark_dec,
    enc_self_mask=None,
    dec_self_mask=None,
    dec_enc_mask=None,
):
    # embed encoder input together with its time features, then encode
    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

    # embed decoder input and cross-attend to the encoder output
    dec_out = self.dec_embedding(x_dec, x_mark_dec)
    dec_out = self.decoder(
        dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask
    )

    if self.output_attention:
        return dec_out[:, -self.pred_len :, :], attns
    return dec_out[:, -self.pred_len :, :]  # [B, L, D]
```
Code above shows the current forward logic.

Summarizing above ...

Problem Statement
The sktime interface provides only the single input tensor x to the network's forward pass.

Proposed Design
Now we have to make changes in the pytorch adapter or the forward pass, so that it translates the single x into the four inputs the network expects (x_enc, x_mark_enc, x_dec, x_mark_dec). These changes are needed to be made on the ...

Requested Feedback
@benHeid
I would propose to take a look at their experiment file exp/exp_main.py and at their dataset implementation. In the dataset classes, the separation into y and x is done. Afterwards, the masking for prediction etc. is done in the exp_main file. Regarding Q1: I would propose to use exactly their approach, but I assume that yours is quite similar. Regarding Q2: y would be a concatenation of the historical values and zeros, if I understood their implementation correctly. Please ping me if you would like clarification at some point.
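For illustration, the concatenation described above could look like the minimal sketch below; the function name and the label_len/pred_len arguments are illustrative assumptions modeled on the cure-lab code, not the PR's actual implementation:

```python
import torch

def build_decoder_input(y_hist, label_len, pred_len):
    """Build x_dec the way cure-lab's exp_main does: the last
    `label_len` observed values as a warm-up window, followed by
    zeros as placeholders for the `pred_len` forecast horizon."""
    # y_hist: [Batch, Input Sequence Length, Channel]
    batch_size, _, num_features = y_hist.shape
    zeros = torch.zeros(batch_size, pred_len, num_features)
    return torch.cat([y_hist[:, -label_len:, :], zeros], dim=1)
```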
@benHeid I think we cannot adopt the ...
force-pushed from 1bb9b95 to 600d189
@fkiraly this is how the docstring looks at the moment:
@benHeid I have tried to train and predict using the current implementation (no changes in the architecture), and the predictions don't seem to move at all. Should I look into this now by trying other loss functions and optimizers, or keep working on the code and check this later?
I believe the dataset class is complete and ready for review. I have made some changes to the original dataset class provided by cure-lab. These are only syntax changes, with no change to the core logic, although I have removed some code that was not useful in the interface.

Changes made
Todos
By the way, this is how the dataloader looks:
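As a rough sketch of the windowing such a dataset class performs (class and attribute names here are illustrative, adapted from cure-lab's Dataset_Custom rather than the exact PR code; time-feature marks are omitted for brevity):

```python
import torch
from torch.utils.data import Dataset

class SlidingWindowDataset(Dataset):
    """Yields (seq_x, seq_y) windows for encoder and decoder inputs."""

    def __init__(self, data, seq_len, label_len, pred_len):
        # data: [Time, Channel] tensor holding the series
        self.data = data
        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len

    def __len__(self):
        # number of full windows that fit into the series
        return len(self.data) - self.seq_len - self.pred_len + 1

    def __getitem__(self, i):
        # encoder input: seq_len past values
        seq_x = self.data[i : i + self.seq_len]
        # decoder target: label_len overlap plus pred_len horizon
        r_begin = i + self.seq_len - self.label_len
        r_end = i + self.seq_len + self.pred_len
        seq_y = self.data[r_begin:r_end]
        return seq_x, seq_y
```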
The following changes were made to reformat the init params:
force-pushed from fa52cda to 9f9193d
I agree
Can you please resolve merge conflicts and ensure tests run, @geetu040?
Looks good!
One question: I see the test test_predict_time_index_in_sample_full is skipped - why is that?
Because the forecaster cannot make in-sample forecasts.
What you describe is that the forecaster cannot make in-sample forecasts - this should be addressed by setting the insample tag correctly, and then the test should work:
it will check that the tag is correctly set, I believe
yes, it checks via the tag - updated the tests config
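For context, setting such a capability tag looks roughly like the stub below; the class is purely illustrative, and only the tag entry is the point (capability:insample is the standard sktime tag name for in-sample support):

```python
from sktime.forecasting.base import BaseForecaster

class MyDeepForecaster(BaseForecaster):
    """Illustrative stub - only the tag handling matters here."""

    _tags = {
        # declares that this forecaster cannot produce in-sample
        # predictions; the test framework reads this tag and adapts
        # its expectations instead of requiring manual test skips
        "capability:insample": False,
    }
```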
Well, looks like we did not have to skip those tests after all!
Would appreciate a review from @benHeid before merging, given his familiarity with the algorithm - this is a complex PR.
Reference Issues/PRs
Implements LTSFTransformer from #4939

What does this implement/fix? Explain your changes.
New forecaster: LTSFTransformer
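Usage would look roughly like the sketch below; the import path and constructor parameters (seq_len, pred_len) are assumptions for illustration, not necessarily the final merged API:

```python
from sktime.datasets import load_airline
from sktime.forecasting.ltsf import LTSFTransformerForecaster  # assumed import path

y = load_airline()

# seq_len: length of the context window fed to the encoder
# pred_len: length of the forecast horizon produced by the decoder
forecaster = LTSFTransformerForecaster(seq_len=36, pred_len=12)
forecaster.fit(y, fh=list(range(1, 13)))
y_pred = forecaster.predict()
```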
Does your contribution introduce a new dependency? If yes, which one?
No
What should a reviewer concentrate their feedback on?
Did you add any tests for the change?
Not yet
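If helpful, the standard sktime estimator test suite can be run locally against the new estimator; a minimal sketch, assuming the import path used above:

```python
from sktime.utils.estimator_checks import check_estimator
from sktime.forecasting.ltsf import LTSFTransformerForecaster  # assumed import path

# runs the standard interface-conformance test suite for the estimator,
# collecting failures instead of raising on the first one
results = check_estimator(LTSFTransformerForecaster, raise_exceptions=False)
```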
PR checklist

For all contributions
- Optionally, I've added myself to the maintainers tag - do this if you want to become the owner or maintainer of an estimator you added. See here for further details on the algorithm maintainer role.

For new estimators
- I've added the estimator to the API reference - in docs/source/api_reference/taskname.rst, follow the pattern.
- I've added one or more illustrative usage examples to the docstring, in its Examples section.