- Install the requirements:
  ```
  pip install -r requirements.txt
  ```
- Edit `config.yaml` (see below).
- Run the main script:
  ```
  python main.py config.yaml
  ```
- Navigate to the newly created save folder under `runs/` and analyze the results, e.g. by rerunning the main script.
Several configuration files corresponding to the experiments presented in the paper can be found in `cfgs/`.
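To reproduce an experiment, pass one of these files to the main script, e.g. `python main.py cfgs/<config>.yaml` (the file name here is a placeholder; see the folder for what is available).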
This section outlines the configuration parameters used for training and evaluating the models.
Parameter | Description |
---|---|
`seed` | Random seed for reproducibility. |
`train` | Whether to train the model or not. |
`dataset` | Dataset to use (`mnist`, `celeba`, `uea`, etc.). |
`data_folder` | Path to the data folder. |
`classification` | Whether this is a classification or a forecasting task. |
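For reference, these general settings might look as follows in `config.yaml` (a minimal sketch; the values are illustrative assumptions, not the repo's defaults):

```yaml
seed: 42                # random seed for reproducibility
train: true             # set to false to skip training and only evaluate
dataset: mnist          # one of mnist, celeba, uea, ...
data_folder: ./data/    # path to the data folder (illustrative)
classification: false   # false => forecasting task
```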
Parameter | Description |
---|---|
`resolution` | Resolution of the input images (for CelebA). |
`downsample_factor` | Resolution downsampling factor for MNIST. |
`traj_length` | Unused! |
`normalize` | Whether to scale the data to the [-1, 1] range. |
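A matching sketch of the data options (values again illustrative):

```yaml
resolution: 32          # input image resolution (CelebA only)
downsample_factor: 2    # resolution downsampling factor (MNIST only)
traj_length: null       # unused
normalize: true         # scale the data to [-1, 1]
```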
Parameter | Description |
---|---|
`model_type` | `wsm` for WARP, `lstm`, or `gru`. |
`root_hidden_size` | Size of the hidden layers in the root network. |
`root_depth` | Number of layers in the root network. |
`root_activation` | Activation function for the root MLP. |
`root_final_activation` | Final activation function for the root MLP's output mean. |
`std_lower_bound` | Lower bound for clipping the standard deviation. |
`nb_rnn_layers` | Number of RNN layers (fixed at 1 for now!). |
`init_state_layers` | Number of layers in the initial hypernetwork (`null` to sidestep it completely). |
`input_prev_data` | Whether the root network uses the previous observation as input. |
`weights_lim` | Limit for the weights of the root model. |
`time_as_channel` | Whether to use time as an additional input channel. |
`forcing_prob` | Probability of using the true input during teacher forcing. |
`noise_theta_init` | Whether to add noise to `theta_0` at the start of the recurrence. |
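A sketch of the model options; the concrete values and activation names below are assumptions for illustration:

```yaml
model_type: wsm             # wsm (WARP), lstm, or gru
root_hidden_size: 64        # width of the root network's hidden layers
root_depth: 3               # number of layers in the root network
root_activation: relu      # hidden activation of the root MLP (name assumed)
root_final_activation: tanh # activation on the output mean (name assumed)
std_lower_bound: 0.0001     # clip the predicted standard deviation from below
nb_rnn_layers: 1            # fixed at 1 for now
init_state_layers: null     # null sidesteps the initial hypernetwork
input_prev_data: false      # feed the previous observation to the root network
weights_lim: 1.0            # clipping limit for the root model's weights
time_as_channel: true       # use time as an additional input channel
forcing_prob: 0.5           # probability of teacher forcing with the true input
noise_theta_init: false     # add noise to theta_0 at the start of the recurrence
```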
Parameter | Description |
---|---|
`init_lr` | Initial learning rate. |
`gradient_lim` | Gradient limit for clipping. |
`on_plateau` | Set of parameters for the `reduce_on_plateau` strategy in Optax. |
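A sketch of the optimizer options; the `on_plateau` sub-keys shown are an assumption, mirroring the arguments of Optax's `reduce_on_plateau` transformation:

```yaml
init_lr: 0.001        # initial learning rate
gradient_lim: 1.0     # gradient-clipping limit
on_plateau:           # sub-keys assumed to follow optax reduce_on_plateau
  factor: 0.5         # multiply the learning rate by this on a plateau
  patience: 10        # how many stagnant validations to tolerate first
```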
Parameter | Description |
---|---|
`nb_epochs` | Number of training epochs. |
`batch_size` | Batch size for training. |
`print_every` | How often to print training progress. |
`save_every` | How often to save artefacts, e.g. losses. |
`valid_every` | How often to evaluate on the validation set. |
`val_criterion` | Metric for choosing the best model. |
`inference_start` | Context length for autoregressive digit generation. |
`autoregressive` | Whether to train in autoregressive (AR) mode or convolution mode. |
`stochastic` | Whether to use the reparametrization trick or simply take the mean. |
`nb_recons_loss_steps` | Number of steps to sample for the reconstruction loss. |
`use_nll_loss` | Whether to use NLL or MSE loss for forecasting. |
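Finally, a sketch of the training options (values illustrative; the units of the `*_every` parameters are an assumption):

```yaml
nb_epochs: 100            # number of training epochs
batch_size: 32            # training batch size
print_every: 10           # print progress every N epochs (unit assumed)
save_every: 50            # save artefacts (e.g. losses) every N epochs
valid_every: 10           # run validation every N epochs
val_criterion: mse        # metric for picking the best model (name assumed)
inference_start: 10       # context length for AR digit generation
autoregressive: true      # AR mode (false => convolution mode)
stochastic: true          # reparametrization trick (false => take the mean)
nb_recons_loss_steps: 16  # steps sampled for the reconstruction loss
use_nll_loss: true        # NLL loss (false => MSE) for forecasting
```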