Poor scaling on AArch64 at high thread counts #155795
Open
@fadara01

Description

It's a bit more than poor scaling. ML workloads generally regress in performance at 32 threads and degrade further at 64 threads (compared to 16 threads).

This was observed while benchmarking LLaMA 3.1 with vLLM, where we noticed that paged_attention at 32 threads runs slower than at 16, and even slower at 64.
Running the paged_attention benchmarks under perf revealed that:

  • vLLM uses the libgomp (version 8.5) bundled in the PyTorch wheel (built in the manylinux_2_28_aarch64 Docker image).
  • Roughly 74% of total time is spent in gomp_iter_dynamic_next (see the profile below).
    98.16%     0.00%  python   libc.so.6                                   [.] thread_start
            |
            ---thread_start
               |
                --98.16%--start_thread
                          |
                           --97.94%--gomp_thread_start
                                     |
                                     |--90.08%--paged_attention_v1_impl<float,…>::call [._omp_fn.0]
                                     |          |
                                     |          |--74.07%--gomp_iter_dynamic_next
                                     |          |
                                     |           --7.00%--reduceValueBlock<float,…>::λ(int)
                                     |
                                      --7.55%--gomp_thread_start
                                                |
                                                |--4.17%--gomp_barrier_wait_end
                                                |
                                                 --3.38%--gomp_team_barrier_wait_end

Interestingly, running with my local libgomp 11.4 (via LD_PRELOAD) yields out-of-the-box speedups at high thread counts, and the perf output looks far more reasonable, with only 1.22% of time spent in gomp_iter_dynamic_next:

    97.60%     0.00%  python   libc.so.6                                   [.] thread_start
            |
            ---thread_start
               |
                --97.60%--start_thread
                          |
                          |--96.98%--gomp_thread_start
                          |          |
                          |          |--75.30%--paged_attention_v1_impl<float,…>::call [._omp_fn.0]
                          |          |          |
                          |          |          |--28.96%--reduceValueBlock<float,…>::λ(int)
                          |          |          |
                          |          |          |--1.22%--gomp_iter_dynamic_next
                          |          |          |          |
                          |          |          |           --1.15%--__aarch64_ldadd8_sync
                          |          |          |
                          |          |          |--1.08%--expf@@GLIBC_2.27
                          |          |          |
                          |          |           --0.95%--FP32Vec16::reduce_sum()
                          |          |
                          |          |--11.56%--gomp_barrier_wait_end
                          |          |
                          |           --10.10%--gomp_team_barrier_wait_end
                          |
                           --0.61%--blas_thread_server
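
As a sanity check when experimenting with LD_PRELOAD, here is a minimal sketch (not from the original report) for confirming which OpenMP runtime a PyTorch process has actually loaded; it only relies on torch.__config__.parallel_info() and Linux's /proc/self/maps:

# Sketch: confirm which libgomp this process loaded, e.g. with and without
# LD_PRELOAD of a newer runtime. Assumes Linux (/proc/self/maps).
import torch

# Prints the ATen parallel backend and OpenMP build/runtime settings.
print(torch.__config__.parallel_info())

# List every libgomp mapping in the current process; with LD_PRELOAD this
# should point at the preloaded library rather than the wheel-bundled copy.
seen = set()
with open("/proc/self/maps") as maps:
    for line in maps:
        fields = line.split()
        if fields and "libgomp" in fields[-1] and fields[-1] not in seen:
            seen.add(fields[-1])
            print(fields[-1])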

This mirrors past regressions on AArch64 at high thread counts, previously attributed to suboptimal thread-throttling in Arm Compute Library.

To confirm, I ran the benchmark script attached below with the PyTorch wheel's libgomp (8.5) and with my local libgomp (11.4). From those benchmarks we can see, for example, that the following benchmark:
python bench.py --model bert-base-uncased --mode eager --context_length 128 --dtype int8

  • With the wheel’s libgomp, it runs 3× slower at 64 threads than at 16 threads.
  • With the local libgomp (via LD_PRELOAD), performance is comparable between 16 and 64 threads (no speedup, but no slowdown); a sketch of such a thread-count sweep appears below.

While bert-base isn’t typically run on 64 threads, a 3× slowdown relative to 16 threads is unexpected.
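
For reference, a minimal sketch of the kind of thread-count sweep behind that comparison (the exact harness is not part of this issue; the thread counts and the local libgomp path below are illustrative assumptions):

# Sketch: sweep bench.py over thread counts with and without a preloaded
# libgomp. LOCAL_LIBGOMP is a hypothetical path to a newer (e.g. 11.4) runtime.
import os
import subprocess

BENCH = ["python", "bench.py", "--model", "bert-base-uncased",
         "--mode", "eager", "--context_length", "128", "--dtype", "int8"]
LOCAL_LIBGOMP = "/usr/lib/aarch64-linux-gnu/libgomp.so.1"

for preload in (None, LOCAL_LIBGOMP):
    for threads in (16, 32, 64):
        env = dict(os.environ, OMP_NUM_THREADS=str(threads))
        if preload:
            env["LD_PRELOAD"] = preload  # swap in the newer OpenMP runtime
        print(f"libgomp={'preloaded' if preload else 'wheel'} threads={threads}")
        subprocess.run(BENCH, env=env, check=True)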

How can this be addressed in PyTorch release/nightly wheels?

Unfortunately, the current manylinux_2_28_aarch64 image used to build AArch64 wheels (AlmaLinux 8) only provides libgomp 8.5, and installing a newer GCC toolchain does not upgrade it. The only way I found to update libgomp was to build it from source, as done in PR #152361.

Results

To understand the effect of this libgomp version update, I benchmarked all the following combinations on Arm Neoverse-V1 CPUs:

Below are sample plots showing the percentage speedup (i.e. 100% means a 2x speedup; a one-line definition follows the list below) with the new libgomp (built from source) vs. the current state.
We can see that updating libgomp:

  • Either improves performance or keeps it the same
  • Does not cause any meaningful regressions
  • Yields speedups that peak at small input lengths and high thread counts
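
For clarity on that metric, a one-line sketch of the "percentage of speedup" definition implied above (100% corresponds to a 2x speedup):

# Percentage speedup as plotted: 100% means the new libgomp is 2x faster.
def speedup_pct(old_ms: float, new_ms: float) -> float:
    return (old_ms / new_ms - 1.0) * 100.0

assert speedup_pct(12.0, 6.0) == 100.0  # halving the time -> 100% (2x)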

[Plots: percentage speedup of the source-built libgomp vs. the wheel's libgomp, across thread counts and input lengths]

benchmark script:

# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate <open-source-office@arm.com>
# SPDX-License-Identifier: BSD-3-Clause
import torch
from transformers import AutoModel, AutoConfig
import time
import numpy as np
from argparse import ArgumentParser
from contextlib import nullcontext


class ModelArgumentParser(ArgumentParser):
    def __init__(self) -> None:
        super().__init__(description="huggingface model")
        self.add_argument("--context_length",
                            help="context length - number of input tokens",
                            type=int,
                            default=64
        )
        self.add_argument("--model",
                            help="model checkpoint - i.e. 'bert-base-uncased'",
                            type=str,
                            default='bert-base-uncased')
        self.add_argument("--iters",
                            help="benchmark iterations",
                            type=int,
                            default=500)

        self.add_argument('--dtype',
                            choices=['bfloat16', 'float32', 'int8'],
                            default='float32',
                            help="datatype",
        )
        self.add_argument('--mode',
                            choices=['compile', 'eager'],
                            default='eager',
                            help="execution mode",
        )

if __name__ == "__main__":
    parser = ModelArgumentParser()
    args = parser.parse_args()
    assert not (args.mode == "compile" and args.dtype == "int8")
    
    model_name = args.model
    config = AutoConfig.from_pretrained(model_name)
    batch_size = 1
    model = AutoModel.from_pretrained(model_name)

    maybe_autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16) if args.dtype == "bfloat16" else nullcontext()
    maybe_disable_torch_function = torch._C.DisableTorchFunction() if args.dtype == "int8" else nullcontext()

    with torch.no_grad(), maybe_autocast_ctx, maybe_disable_torch_function:
        if args.dtype == "int8":
            model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
            
        model.eval()
        inputs = torch.randint(config.vocab_size, (batch_size, args.context_length), dtype=torch.long, device="cpu")

        times = []

        if args.mode == "compile":
            model = torch.compile(model, fullgraph=True)

        # warmup
        for _ in range(10):
            model(inputs)
        # benchmark
        for _ in range(args.iters):
            s = time.time_ns()
            model(inputs)
            times.append((time.time_ns() - s) / 1e6)

        avg_time = np.mean(times)
        print(f"Time: {avg_time:.2f} ms")

cc @msaroufim @jerryzh168 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm

Labels

module: arm (Related to ARM architecture builds of PyTorch; includes Apple M1), module: performance (Issues related to performance, either of kernel code or framework glue), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
