Description
I'm running `torch.compile` on a transformer-based model. When I run the compiled function, I don't get any recompiles on subsequent calls in the same script; for example:
```python
m = torch.compile(self.transformer.forward, fullgraph=True, dynamic=True, mode="max-autotune", disable=False)
m(...)
m(...)
```
works fine, with no recompiles. But when I rerun the script, the `torch.compile` warmup takes quite a long time: about 10 seconds per input batch size (the warmup runs twice each for batch sizes 1, 2, and 4). I ran the script with `TORCH_LOGS=+dynamo,+inductor,+recompiles,+aot`, and most of the time appears to be spent tracing the code and solving shape constraints, even though the model code is unchanged.
Files are written to the disk cache, and the disk cache does speed up compile times, but I'm not sure why tracing and shape solving still take so long when the model code is unchanged. Are the traced graphs not cached?
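For completeness, a sketch of the cache-related knobs involved (the env var names follow the Inductor environment variables; the cache path is just an illustrative example):

```python
# These must be set before torch is imported.
import os

# Persist Inductor's compiled FX graphs on disk across processes.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"

# Where the on-disk cache lives; the path here is illustrative.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/inductor-cache"
```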
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @oulgen @jamesjwu @aorenste @anijain2305 @laithsakka @masnesral