
any4 + tinygemm

This repo contains tinygemm, a low-latency / small-batch-size Nvidia GPU GEMM library that implements bf16/fp16, int4 grouped quantization, any4 grouped quantization, and MX4 quantization, as well as the code for the technique to learn any4 quantization codes.

This code release is meant to accompany the accepted ICML 2025 paper any4: Learned 4-bit Numeric Representation for LLMs by Mostafa Elhoushi and Jeff Johnson, which will be available on arXiv soon.

The technique and code for learning any4 representations and quantizing a model were authored by Mostafa Elhoushi (previously Meta FAIR SysML research). The Nvidia GPU tinygemm library was authored by Jeff Johnson (currently Meta FAIR SysML research). An extremely early version of the tinygemm kernels, without any4/MX4 support, was upstreamed to PyTorch core in Q4 2023 for use by the Torch compiler.

What is any4?

There is a wide variety of 4-bit numerical formats implemented on CPU/GPU for ML inference, such as uniform int4 quantization, "fp4", NF4, AF4 and the like, all of which have their dequantization values fixed a priori. any4 instead uses a lookup table (LUT) to translate the 16 possible 4-bit quantization codes to arbitrary bfloat16 or float16 floating-point values, and this GPU in-register LUT is used at dequantization time. Each row of a weight matrix can use a different 16 x bfloat16/float16 LUT, so the quantization codes can be tailored to each row of the matrix. k-means or neural-network-based clustering is used to learn the any4 LUTs from the weight matrix data distribution. Effectively, any4 is 4-bit grouped quantization like typical int4 quantization, except that the code dequantization values (prior to scale and offset) are arbitrary floating-point values from the LUT rather than integers in the range [-8, +7] or [0, 15]. any4 is thus a very efficient means of implementing NormalFloat4 (NF4) or AbnormalFloat4 (AF4), whose initial implementations used GPU-unfriendly deeply nested if/else blocks or switch statements.
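As a rough illustration, the sketch below (plain PyTorch, not the repo's implementation; the function names, the absmax per-group scaling, and the group size are illustrative assumptions) learns a 16-entry LUT for one weight row with a few k-means iterations, quantizes the row to 4-bit codes, and dequantizes via the LUT:

import torch

def any4_quantize_row(w_row: torch.Tensor, group_size: int = 64):
    # Per-group absmax scale (illustrative; the paper's scale/offset scheme may differ).
    groups = w_row.view(-1, group_size)
    scales = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    normed = (groups / scales).flatten()

    # Learn a 16-entry LUT for this row with a few k-means (Lloyd) iterations.
    lut = torch.linspace(normed.min().item(), normed.max().item(), 16)
    for _ in range(20):
        codes = (normed[:, None] - lut[None, :]).abs().argmin(dim=1)
        for c in range(16):
            members = normed[codes == c]
            if members.numel() > 0:
                lut[c] = members.mean()

    codes = (normed[:, None] - lut[None, :]).abs().argmin(dim=1)  # one 4-bit code per weight
    return codes.to(torch.uint8), lut, scales

def any4_dequantize_row(codes, lut, scales, group_size: int = 64):
    # Dequantize: LUT lookup (in-register on GPU; plain indexing here), then rescale.
    normed = lut[codes.long()].view(-1, group_size)
    return (normed * scales).flatten()

w_row = torch.randn(4096)
codes, lut, scales = any4_quantize_row(w_row)
w_hat = any4_dequantize_row(codes, lut, scales)
print((w_row - w_hat).pow(2).mean())  # per-row reconstruction MSE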

The tinygemm low-latency GPU GEMM library implements any4 quantization. Learning the any4 quantization codes is not part of tinygemm itself. While tinygemm supports nearly arbitrary GEMM sizes (assuming the reduction/k dimension is a multiple of 16 or 32), it is primarily meant for matrix multiplication problems where one of the m or n problem dimensions (for a (m x k) x (n x k)^t matrix multiplication) is smaller than a GPU tensor core tile size (e.g., 1 <= m <= 16 or 1 <= n <= 8), as is typical for the "activation" vector in neural networks.

tinygemm has two different modes: one computes Y = X W^t and the other computes Y = (W X^t)^t (both produce the same result; the difference is whether the "weight" matrix is the "A" or "B" matrix for tensor core usage). All needed transpositions are performed on the fly by tinygemm. For the m16n8k16 A100+ bf16/fp16 tensor core tile, the "A" matrix tile size is 16 x 16 and "B" is 8 x 16 (or 16 x 8 as desired). Putting the activations (e.g., a 1 x k matrix) on the right and the weights on the left (so that the 1 x k matrix occupies the "B" tile) ensures that the tensor core unit runs at 1/8th throughput rather than 1/16th throughput. We have found that using the tensor core in this fashion for, e.g., GEMV is quite fast. tinygemm does not use larger tensor core multiplication primitives (again, because a typical use case is something like a (1 x k) x (n x k) GEMM). All matrices presented to tinygemm must be row-major with the reduction dimension k innermost.
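For reference, the snippet below (plain PyTorch, not the tinygemm API; the shapes are illustrative) just checks that the two modes compute the same mathematical result for a GEMV-like problem:

import torch

m, n, k = 1, 8, 4096          # e.g., a single activation row against 8 weight rows
X = torch.randn(m, k)         # activations, row-major, reduction dim k innermost
W = torch.randn(n, k)         # weights, row-major, reduction dim k innermost

Y1 = X @ W.t()                # mode 1: Y = X W^t
Y2 = (W @ X.t()).t()          # mode 2: Y = (W X^t)^t
print(torch.allclose(Y1, Y2, atol=1e-5))  # True, up to floating-point rounding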

To further reduce latency, it is best to lay out weight matrices in "tensor core" format so that no shared memory transposition is needed. Because the weight matrix is usually not reused, the kernels avoid shared memory entirely for buffering or transposition and load data directly from gmem into registers (with some degree of multi-buffering into registers, though nvcc/ptxas' register usage heuristics are at odds with this; loads from gmem into a register remain asynchronous until the point of first use).

Please refer to the paper for additional details.

Getting Started

  1. Clone Repo
git clone git@github.com:facebookresearch/any4.git

cd any4
  2. Setup Environment
conda create --name any4 python=3.10
conda activate any4

pip install -r requirements.txt
  3. Access Models

Some models (e.g., Llama) require permission. Follow these steps to access them:

a. Submit a request to access a Llama checkpoint, e.g., https://huggingface.co/meta-llama/Llama-3.2-1B.

b. Set up Hugging Face token access by following the steps described here.

c. Then you will be able to log in to Hugging Face by running the command below and entering the token you obtain from Step b. above:

huggingface-cli login
  4. Install tinygemm kernels
cd tinygemm
python setup.py install
cd ..

Run

Most of the scripts below will run the baseline fp16 model by default. To quantize, add the following arguments (a combined example follows the list):

  • --model-args: pass in any args that are passed to Hugging Face's from_pretrained() function, including load_in_4bit and load_in_8bit.
  • --quantize: selects one of the (fake) quantization algorithms implemented in this codebase. It can take: intq (integer quantization), fp4 (4-bit float quantization), nf4 (4-bit normal float quantization), or anyq (the proposed lookup table quantization).
    • --quantize-args: comma-separated arguments to pass to a quantization algorithm, e.g., --quantize-args n_bit=4,group_size=32 will perform 4-bit quantization with group size 32.
  • --bnb-args: comma-separated arguments to pass to BitsAndBytesConfig, e.g., load_in_4bit=True,bnb_4bit_compute_dtype=fp32
  • --torchao-args: TBD
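For example, the flags compose; the invocation below is illustrative (check each script's --help for the exact flags it accepts):

python eval.py --model-name meta-llama/Llama-3.2-1B --quantize anyq --quantize-args n_bit=4,group_size=32 --tasks piqa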

Quick Example

To run a simple text generation (with and without quantization) example script that you can try and edit:

python example.py

Generation

TBD

Evaluation

Evaluate a model (with or without quantization) on downstream tasks.

  • Baseline fp16 model:
python eval.py --model-name facebook/opt-125m --tasks piqa
  • Quantized int4 model:
python eval.py --model-name facebook/opt-125m --quantize intq --tasks piqa

Arguments:

  • --tasks: by default it runs a large number of natural language, coding, and perplexity evaluation tasks.

Analyze

To analyze weights, and the mean square error on weights and activations between the baseline model and the quantized model at each layer:

python analyze.py --model-name meta-llama/Llama-3.2-1B --quantize nf4

Calibrate

To pass a dataset or prompt over a model and store the output activations of each layer:

python calibrate.py --model-name meta-llama/Llama-3.2-1B --dataset cerebras/SlimPajama-627B --num-batches 10

Diff

To pass a prompt to both a baseline model and a quantized model and measure the mean square error at each layer:

python analyze.py --model-name meta-llama/Llama-3.2-1B --quantize anyq

Test

python -m pytest .

TODOs:

  • Add Notebook
  • Integrate with torchao
  • any4 LUT dequantization is currently done via warp shuffle in the GEMM core, but higher throughput might be achievable by instead using smem to dequantize 2 x any4 codes (1 byte) at a time, at the possible expense of added bank conflicts.

License

tinygemm and any4 quantization code are CC-BY-NC 4.0 licensed, as found in the LICENSE file.
