Train models on TPUs with Pytorch/XLA #8
@jqhoogland please note in particular that I did not get around to testing the TPU RNG for the first full sweep. If the runs were not deterministic as a result, this has two consequences for the data from this first run:
I will look into the deterministic TPU stuff when I get a chance. I don't know how to check the checkpoints for clashes. |
With the checkpointer it just overwrites any files, so the checkpoints should all just be from the most recent run. |
It might be the case that multiple runs for the same config were happening at roughly the same time, so it is not clear which of the wandb runs the checkpoints correspond to. |
On path 6 (parallelism within TPUs, by device): Turns out this was trivially straightforward with the right knowledge.
To spell this out, here's how to run four training runs in parallel. The basic principle is to set some environment variables that configure the TPU into a mode where it keeps the four devices separate, and then further variables to select one of the four devices. The environment variables to use are as follows:
So suppose the usual training command was as follows:
Then you could run the following four commands in four different shells to do this four times in parallel (scroll right to see the differences in the ports and device number):
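As an illustrative sketch only (not necessarily the exact variables or command used here): assuming the commonly documented PyTorch/XLA recipe for pinning a process to a single TPU chip, and a placeholder `train.py` entry point, the four parallel invocations might look roughly like this:

```bash
# Hypothetical sketch: variable names follow the commonly documented PyTorch/XLA
# single-chip recipe; train.py is a placeholder for the real training command.
TPU_VISIBLE_DEVICES=0 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476 python train.py
TPU_VISIBLE_DEVICES=1 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8477 TPU_MESH_CONTROLLER_PORT=8477 python train.py
TPU_VISIBLE_DEVICES=2 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8478 TPU_MESH_CONTROLLER_PORT=8478 python train.py
TPU_VISIBLE_DEVICES=3 TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8479 TPU_MESH_CONTROLLER_PORT=8479 python train.py
```

Each command pins its process to one of the four devices and gives it its own mesh-controller port, so the four runs do not interfere with each other.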
The final step would be to find an easy way to set these variables for an agent without having to copy/paste them every time (tedious and error-prone). Edit to add: Spell out these commands once in a script with custom output redirection, nohup, and backgrounding, and then these 4 lines can just be pasted into a single SSH session for each TPU VM. That seems easy enough. Seems to work (currently running 40 experiments at once). Will add to the guide and call it done. |
On step 2 (Deterministic randomness): XLA does appear to have its own RNG with a seed function. However, when I tested some of the tensor values in the first iteration of a training run, I noticed they were already the same, even though we haven't seeded it. This is confusing, and we should investigate why it is happening to ensure our runs are reproducible.
A consequence is that the runs we did in the past were possibly already seeded. We could confirm that, but going forward we should aim to resolve the above confusion and, if necessary, seed the XLA RNG to ensure deterministic computation. If we do need to seed the XLA RNG manually, then here's how I think that would work:
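A minimal sketch, assuming `torch_xla`'s `xm.set_rng_state` is the right entry point (the helper name and exact placement in our training code are just for illustration):

```python
# Sketch only: seed both the host RNG and the XLA device RNG.
import torch
import torch_xla.core.xla_model as xm

def seed_everything_xla(seed: int) -> None:
    torch.manual_seed(seed)  # host-side (CPU) RNG
    xm.set_rng_state(seed)   # XLA device RNG (for the current default device)
```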
|
On 5 (Parallelisation across TPUs): W&B sweeps is the right tool for this. The instructions are now updated to reflect how to do this. The command to allow logging out after launching an agent would be:
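Roughly something along these lines (the sweep path below is a placeholder), so the agent keeps running after the SSH session ends:

```bash
# Launch a W&B sweep agent that survives logout; redirect output to a log file.
nohup wandb agent ENTITY/PROJECT/SWEEP_ID > agent.log 2>&1 &
```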
This seems to work, and any further issues with these agents should be noted separately. |
On 4 (Basic optimisations): I modified the training loop to work with both XLA and without, summarised here. First, we only need to import the XLA libraries if the device is XLA. I propose configuring with a string ('xla') and checking at the start of training whether we want to use XLA:

```python
# special code if device is 'xla'
XLA = (config.device == 'xla')
if XLA:
    stdlogger.info("device is 'xla'! some special code will run...")
    stdlogger.info("importing torch_xla...")
    import torch_xla.core.xla_model as xm
    # import torch_xla.debug.metrics as met
    stdlogger.info("configuring default XLA device...")
    device = xm.xla_device()
    stdlogger.info("xla ready!")
else:
    device = config.device
```

Then we initialise the model and data as usual, moving them to the configured device:

```python
# model initialisation
stdlogger.info("initialising model")
model = config.task_config.model_factory().to(config.device)
model.train()
# initialise 'pretraining' data source (for training on fixed task set)
stdlogger.info("initialising data (pretrain)")
pretrain_dist = config.task_config.pretrain_dist_factory().to(config.device)
# initialise 'true' data source (for evaluation, including unseen tasks)
stdlogger.info("initialising data (true)")
true_dist = config.task_config.true_dist_factory().to(config.device)
```

The evaluator involves running some code (the baselines) on this device, so we need to use mark steps around it:

```python
# initialise evaluations
stdlogger.info("initialising evaluator")
if XLA: xm.mark_step()
evaluator = ICLEvaluator(
    pretrain_dist=pretrain_dist,
    true_dist=true_dist,
    max_examples=config.task_config.max_examples,
    eval_batch_size=config.eval_batch_size,
    seed=config.task_config.true_seed
)
if XLA: xm.mark_step()
```

Initialise the monitoring code and optimisers as usual (this shouldn't require anything XLA-specific?):

```python
# initialise monitoring code
stdlogger.info("initialising checkpointer and logger")
checkpointer = config.checkpointer_config.factory() if config.checkpointer_config is not None else None
logger = config.logger_config.factory() if config.logger_config is not None else None
# initialise torch optimiser
stdlogger.info("initialising optimiser and scheduler")
optimizer = config.optimizer_config.factory(model.parameters())
scheduler = config.scheduler_config.factory(optimizer)  # type: ignore
```

(Actually, come to think of it, `model.parameters()` is on the XLA device; TODO later: see if a mark step after that helps?)

There was some code to track recent losses; however, I think this might affect the computational graph because of the dependence on `step`? I don't know, worth testing; for now I have disabled it:

```python
# TODO: this is unused and may be slowing down XLA... use it or lose it
# recent_losses = torch.zeros(100, device=config.device)
```

Now the training loop!

```python
# training loop
stdlogger.info("starting training loop")
stdlogger.info("note: first two iterations slow while XLA compiles")
stdlogger.info("note: early iterations slow due to logspace checkpoints")
for step in tqdm.trange(config.num_steps, desc="training..."):
    # per-step seeds for reproducibility if we resume training
    set_seed(config.task_config.sampling_seed + step)
```

The first thing inside the training loop is the training step itself. I bounded this by mark steps to make sure it's really isolated:

```python
    # training step
    if XLA: xm.mark_step()
    xs, ys = pretrain_dist.get_batch(
        num_examples=config.task_config.max_examples,
        batch_size=config.batch_size,
    )
    ys_pred = model(xs, ys)
    loss = F.mse_loss(ys, ys_pred)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    scheduler.step()
    if XLA: xm.mark_step()
```

More recent-losses stuff, commented out:

```python
    # see above
    # recent_losses[step % 100] = loss
```

Logging the batch loss pulls the value off the device (via `.item()`), so we only do it every 100 steps:

```python
    # wandb logging: log batch loss every 100 steps
    if step % 100 == 0 and step > 0 and config.is_wandb_enabled:
        stdlogger.info("logging batch loss at step %s", step)
        # TODO: Figure out how to make this work with `logger`
        wandb.log({"batch/loss": loss.mean().item()}, step=step)
```

Every now and then we run the evaluations; definitely mark that so it compiles separately:

```python
    # evaluate and log metrics to wandb according to log_steps
    if step in config.logger_config.logging_steps:
        stdlogger.info("evaluating metrics at step %s", step)
        if XLA: xm.mark_step()
        model.eval()
        metrics = evaluator(model)
        model.train()
        if XLA: xm.mark_step()
        stdlogger.info("logging metrics at step %s", step)
        logger.log(metrics, step=step)
```

And finally, checkpointing. Here's where I suggest we should move the model off the TPU (to deal with part 3 of this issue). I don't know if that requires marking, probably not?

```python
    # save checkpoints according to checkpoint_steps
    if step in config.checkpointer_config.checkpoint_steps:
        # TODO: if xla: move model to CPU before saving
        stdlogger.info("saving checkpoint at step %s", step)
        if XLA: xm.mark_step()
        checkpointer.save_file(step, state_dict(model, optimizer, scheduler))
        if XLA: xm.mark_step()
```

That's all folks.

```python
if config.is_wandb_enabled:
    wandb.finish()
```

Probably we should not return a model on the TPU device to code that probably doesn't use the TPU.

```python
# TODO: if XLA, move model off TPU?
return model
```
|
^Oh yeah, it runs fast like this, at least faster than CPU, but unsure if optimal. In particular, this might be more mark steps than necessary, and I don't know if they have overhead. Leaving that for later (see path 9 in the top post in this issue). |
We would like to run our experiments on TPUs. To get this to happen we need to use Pytorch/XLA. This involves at least the following basics:
Python 3.8: Pytorch/XLA runs only on Python 3.8. A small number of our dependencies and code require Python 3.9+. We should identify and work around these. See also devinterp issue 11.
Deterministic randomness: Does XLA have its own TPU-based RNG? Are we seeding it? Look into this and make sure we are getting reproducible runs.
Checkpointing: we need to make sure the exported models can be loaded properly. We had some trouble with this, so it seems we should first move them to the CPU, then save and store the checkpoint (see the sketch below, after this list).
✅ Basic optimisation: Context: Pytorch/XLA tensors are computed lazily. Tensor operations construct a computational graph, and on demand (or explicit request) the computational graph is compiled (with an optimising compiler) and then executed on the TPU. The compilation step is very expensive, and only pays off if the same computational graph is used repeatedly (such as in each iteration of a training loop) where the compilation can be cached. So, to get baseline performance, we need to do the following:
That should lead to ~3x speed up once everything is working on the TPU.
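For the checkpointing item (3) above, a minimal sketch of the "move to CPU, then save" idea (the helper name and `path` argument here are illustrative, not our actual checkpointer API):

```python
# Sketch: detach checkpoint tensors from the XLA device before writing to disk,
# so they can be loaded later on machines without a TPU or torch_xla.
import torch

def save_model_checkpoint_cpu(model, path):
    cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state, path)
    # Note: torch_xla also provides xm.save(...), which performs the
    # XLA-tensor -> CPU move before serialising.
```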
Then there are some pathways to further optimisation that seem low-hanging enough to be worth exploring:
✅ Parallelisation across TPUs (10x speed up): Google TPU Research Cloud offers 5 x TPU v2 and 5 x TPU v3. So that's 10 TPUs that can be conducting independent training runs. The challenge here is to efficiently manage sweeps across 10 independent VMs.
✅ Parallelisation within TPUs (up to 4x speed up): Each TPU v2-8 or v3-8 actually has four two-core chips (so-called 'devices') that can compute in parallel. In other words, so far we are only using 1/4 of each TPU. Possibilities for doing further parallelisation across the four chips:
Stretch goals:
Also use preemptible TPUs (11x speed up): Google TPU Research Cloud offers a further 100 free preemptible TPU v2 (Dan clarifies that 'preemptible' means each VM can be killed at any point, lasting up to 24 hours I think, after which point I assume we can spawn new ones).
These improvements will also be useful for running experiments on non-preemptible VMs, which also sometimes need to be respawned or training resumed from a checkpoint after a crash.
More optimisation (uncertain small speed ups): Beyond just 'getting the TPU to run faster than the CPU' for step (4), there is potentially more room to speed up each training run:
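One low-effort starting point would be to look at XLA's own counters after a few hundred training steps and check for repeated compilations or `aten::` CPU fallbacks. A sketch, using the `torch_xla.debug.metrics` module that is already referenced (commented out) in the training loop above:

```python
# Sketch: dump XLA's internal metrics to see where time is going.
import torch_xla.debug.metrics as met

# ... after running some training steps ...
print(met.metrics_report())  # includes CompileTime, ExecuteTime, and any aten:: fallbacks
```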