Description
It's a bit more than poor scaling. ML workloads generally regress in performance at 32 threads and degrade further at 64 threads (compared to 16 threads).
This was observed while benchmarking LLaMA 3.1 with vLLM, where we noticed that paged_attention at 32 threads runs slower than at 16, and even slower at 64.
Running the paged_attention benchmarks under perf revealed that:
- vLLM uses the libgomp (version 8.5) bundled in the PyTorch wheel (built in the manylinux_2_28_aarch64 Docker image).
- 70% of total time is spent in gomp_iter_dynamic_next.
98.16%  0.00%  python  libc.so.6  [.] thread_start
        |
        ---thread_start
           |
           --98.16%--start_thread
                     |
                     --97.94%--gomp_thread_start
                               |
                               |--90.08%--paged_attention_v1_impl<float,…>::call [._omp_fn.0]
                               |          |
                               |          |--74.07%--gomp_iter_dynamic_next
                               |          |
                               |           --7.00%--reduceValueBlock<float,…>::λ(int)
                               |
                                --7.55%--gomp_thread_start
                                         |
                                         |--4.17%--gomp_barrier_wait_end
                                         |
                                          --3.38%--gomp_team_barrier_wait_end
Interestingly, running with my local libgomp 11.4 (using LD_PRELOAD) yields out-of-the-box speedups at high thread counts, and the perf output looks more reasonable - only 1.22% in gomp_iter_dynamic_next:
97.60%  0.00%  python  libc.so.6  [.] thread_start
        |
        ---thread_start
           |
           --97.60%--start_thread
                     |
                     |--96.98%--gomp_thread_start
                     |          |
                     |          |--75.30%--paged_attention_v1_impl<float,…>::call [._omp_fn.0]
                     |          |          |
                     |          |          |--28.96%--reduceValueBlock<float,…>::λ(int)
                     |          |          |
                     |          |          |--1.22%--gomp_iter_dynamic_next
                     |          |          |          |
                     |          |          |           --1.15%--__aarch64_ldadd8_sync
                     |          |          |
                     |          |          |--1.08%--expf@@GLIBC_2.27
                     |          |          |
                     |          |           --0.95%--FP32Vec16::reduce_sum()
                     |          |
                     |          |--11.56%--gomp_barrier_wait_end
                     |          |
                     |           --10.10%--gomp_team_barrier_wait_end
                     |
                      --0.61%--blas_thread_server
This mirrors past regressions on AArch64 at high thread counts, previously attributed to suboptimal thread-throttling in Arm Compute Library.
To confirm, I ran the benchmark script attached below with the PyTorch-wheel libgomp (8.5) vs. a local libgomp (11.4). From those benchmarks we can see, for example, that the following benchmark:
python bench.py --model bert-base-uncased --mode eager --context_length 128 --dtype int8
- With the wheel’s libgomp, runs 3× slower at 64 threads vs. 16 threads.
- With the local libgomp (via LD_PRELOAD, sketched below), performance is comparable between 16 and 64 threads (no speedup, but no slowdown).
While bert-base isn’t typically run on 64 threads, a 3× slowdown relative to 16 threads is unexpected.
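For completeness, here is roughly how such a side-by-side run can be launched; a minimal sketch, assuming the benchmark script attached below is saved as bench.py, that OMP_NUM_THREADS pins the thread count, and that the libgomp path shown (which will differ between systems) points at the newer library:
# Sketch: run bench.py with the wheel's bundled libgomp, then with a newer
# system libgomp preloaded. The libgomp path and thread count are examples.
import os
import subprocess

BENCH_CMD = ["python", "bench.py", "--model", "bert-base-uncased",
             "--mode", "eager", "--context_length", "128", "--dtype", "int8"]

for label, preload in [("wheel libgomp", None),
                       ("local libgomp 11.4", "/usr/lib/aarch64-linux-gnu/libgomp.so.1")]:
    env = dict(os.environ, OMP_NUM_THREADS="64")
    if preload is not None:
        env["LD_PRELOAD"] = preload  # must be set before the child process starts
    print(f"=== {label} ===")
    subprocess.run(BENCH_CMD, env=env, check=True)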
How can this be addressed in PyTorch release/nightly wheels?
Unfortunately, the current manylinux_2_28_aarch64 image used to build AArch64 wheels (AlmaLinux 8) only provides libgomp 8.5—installing a newer GCC toolchain does not upgrade libgomp. The only way I found to update libgomp was to build it from source in this PR #152361.
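As a sanity check that a given process really picked up the intended libgomp (the wheel's bundled copy vs. an LD_PRELOADed one), something like the following can be run in the same environment; a Linux-only diagnostic sketch:
# Diagnostic sketch (Linux only): show which libgomp the current process has
# mapped after importing torch, plus torch's own view of its threading setup.
import torch

print(torch.__config__.parallel_info())  # OpenMP availability and thread counts

with open("/proc/self/maps") as maps:
    libgomp_paths = {line.split()[-1] for line in maps if "libgomp" in line}

for path in sorted(libgomp_paths):
    print("loaded libgomp:", path)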
Results
To understand the effect of this libgomp version update, I benchmarked all combinations of the following on Arm Neoverse-V1 CPUs (a sketch of the sweep driver follows the list):
- models: distilbert, bert-base, bert-large
- modes: eager, compile
- context_length: 32, 64, 128, 256, 512
- dtype: float32, bfloat16 (autocast), int8 (dynamically quantized)
- threads: 4, 8, 16, 32, 48, 64
- pytorch builds: without and with #152361 (Build libgomp (gcc-11) from src on AArch64)
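The sweep itself is just a loop over these combinations; a minimal driver sketch, assuming the bench.py script attached at the end, that OMP_NUM_THREADS pins the intra-op thread count, and the usual Hugging Face checkpoint names for the three models (the with/without #152361 comparison is done by repeating the sweep on each PyTorch build):
# Sketch of a sweep driver over the combinations listed above. Checkpoint
# names and thread control via OMP_NUM_THREADS are assumptions.
import itertools
import os
import subprocess

models = ["distilbert-base-uncased", "bert-base-uncased", "bert-large-uncased"]
modes = ["eager", "compile"]
context_lengths = [32, 64, 128, 256, 512]
dtypes = ["float32", "bfloat16", "int8"]
threads = [4, 8, 16, 32, 48, 64]

for model, mode, ctx, dtype, nthreads in itertools.product(
        models, modes, context_lengths, dtypes, threads):
    if mode == "compile" and dtype == "int8":
        continue  # bench.py skips this combination (see its assert)
    env = dict(os.environ, OMP_NUM_THREADS=str(nthreads))
    print(f"{model} {mode} ctx={ctx} dtype={dtype} threads={nthreads}")
    subprocess.run(
        ["python", "bench.py", "--model", model, "--mode", mode,
         "--context_length", str(ctx), "--dtype", dtype],
        env=env, check=True)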
Below are sample plots showing the percentage speedup (i.e. 100% means a 2× speedup) with the new libgomp (built from source) vs. the current state.
We can see that updating libgomp:
- Either improves performance or keeps it the same
- Does not cause any meaningful regressions
- Yields its largest speedups at small input lengths and high thread counts
Benchmark script:
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser
from contextlib import nullcontext


class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                          help="context length - number of input tokens",
                          type=int,
                          default=64)
        self.add_argument("--model",
                          help="model checkpoint - i.e. 'bert-base-uncased'",
                          type=str,
                          default='bert-base-uncased')
        self.add_argument("--iters",
                          help="benchmark iterations",
                          type=int,
                          default=500)
        self.add_argument('--dtype',
                          choices=['bfloat16', 'float32', 'int8'],
                          default='float32',
                          help="datatype")
        self.add_argument('--mode',
                          choices=['compile', 'eager'],
                          default='eager',
                          help="execution mode")


if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    # torch.compile of the dynamically quantized int8 model is not supported here
    assert not (args.mode == "compile" and args.dtype == "int8")

    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)

    maybe_autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16) if args.dtype == "bfloat16" else nullcontext()
    maybe_disable_torch_function = torch._C.DisableTorchFunction() if args.dtype == "int8" else nullcontext()

    with torch.no_grad(), maybe_autocast_ctx, maybe_disable_torch_function:
        if args.dtype == "int8":
            model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        model.eval()

        inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")
        times = []

        if args.mode == "compile":
            model = torch.compile(model, fullgraph=True)

        # warmup
        for _ in range(10):
            model(inputs)

        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

        avg_time = np.mean(times)
        print("Time: " + f"{avg_time:.2f} ms")
cc @msaroufim @jerryzh168 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm