Description
Hi Patrick! Congratulations on your research work on neural differential equations. It's quite impressive, and thank you for the `torchcde` and Diffrax libraries.
I've been experimenting with the torchcde module for some time now. I've read the repository and related papers: https://arxiv.org/abs/2005.08926, https://arxiv.org/abs/2106.11028. Currently, I'm working on a time series prediction problem using neural CDEs. I will migrate it to Diffrax, but I have a question, and I think your experience can help me to address it.
In a nutshell, I'm predicting a substance concentration in blood, denoted as $Y$.
I've tried several strategies:
- Training a neural CDE for each patient and validating it with a rolling window strategy. However, the out-of-sample predictions stay very close to the last observed value of $Y$, which may indicate overfitting. Moreover, the `coeffs` obtained from interpolation change size across windows, which makes me doubt the validity and effectiveness of the approach.
- Instead of the window strategy, I considered replacing `t=X.interval` with `t=X.grid_points` in the `torchcde.cdeint(...)` call (assuming that the hidden channel, dim=1, directly represents $Y$). This change would give me an estimated array $\hat{Y}$ for all time steps considered, but the true value of $Y$ is recorded only at specific steps, so I'm not sure how to compute the loss function in this case.
- Another approach I considered is splitting $X$ and $Y$ into one-step $Y$-related measurements for all patients. For example, if $Y$ is available for patient A at $t_{a1}$, $t_{a2}$, $t_{a3}$, ..., I would divide $X$, $Y$ for patient A into batches $[t_{0}, t_{a1-1}]$, $[t_{a1}, t_{a2-1}]$, and so on. I would apply a similar strategy for patient B, and then group all batches from all patients to follow the irregular data strategy described in irregular_data.py. This approach would allow me to perform train/validation/test splits, ensuring that all sets have the same `coeffs` length and making testing more manageable. However, I'm concerned that this strategy loses information: predicting $Y$ at $t_{a2}$ would ignore the records before $t_{a1}$, which may be useful.
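
To make the last strategy concrete, here is a minimal plain-Python sketch of the splitting step only. It does not call `torchcde`. The helper names `split_into_segments` and `pad_to_length`, the toy `patients` dictionary, and the choice to pad by repeating the last row are all assumptions made for illustration. Indices are row positions in $X$ at which $Y$ is recorded.

```python
def split_into_segments(target_idx):
    """Turn the row indices at which Y is observed into (start, stop) pairs.

    Each pair selects X rows start..stop-1, i.e. the window
    [t_start, t_{stop-1}], and is paired with the Y value recorded at row `stop`.
    """
    segments, start = [], 0
    for a in sorted(target_idx):
        if a > start:  # skip a target that has no preceding rows
            segments.append((start, a))
            start = a
    return segments


def pad_to_length(rows, length):
    """Pad a segment to `length` rows by repeating its last row (an assumption)."""
    return rows + [rows[-1]] * (length - len(rows))


# Hypothetical toy data: one feature per time step, Y observed at a few rows.
patients = {
    "A": {"X": [[float(i)] for i in range(10)], "y_idx": [3, 6, 9]},
    "B": {"X": [[float(i)] for i in range(8)], "y_idx": [2, 7]},
}

# Pool the segments from all patients: (patient, segment rows, Y target row).
pooled = []
for name, p in patients.items():
    for start, stop in split_into_segments(p["y_idx"]):
        pooled.append((name, p["X"][start:stop], stop))

# Bring every segment to a common length so they can be batched together.
max_len = max(len(rows) for _, rows, _ in pooled)
padded = [pad_to_length(rows, max_len) for _, rows, _ in pooled]
```

The sketch also makes my concern visible: for patient A, the segment paired with the target at row 6 contains only rows 3 to 5, so rows 0 to 2 never reach that prediction.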
As you can see, this is mainly a question about preprocessing and train/test strategy, but given how input data is fed to neural CDEs, I think it is worth thinking over carefully. Any comments would be greatly appreciated. Thank you very much!