From 657bf4f90b36d8b01ad1412fa49d7fdb606871b0 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Thu, 3 Apr 2025 06:55:55 -0700
Subject: [PATCH] add fallback kernel and interface

Summary:
The fallback kernel covers a) the case when arch-specific kernels are not
available and b) the case when the input conditions don't meet the optimized
kernel's requirements. Also added a kernel selection function that should
later be replaced by some kind of registry.

Reviewed By: metascroy

Differential Revision: D71370598
---
 .../cpu/aarch64/tests/test_qmatmul.cpp        |   1 +
 .../channelwise_8bit_a_channelwise_8bit_b.h   | 133 ++++
 .../kernels/cpu/interface/quantized_matmul.h  |  88 ++++
 .../cpu/interface/test_qmatmul_interface.cpp  | 448 ++++++++++++++++++
 4 files changed, 670 insertions(+)
 create mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h
 create mode 100644 torchao/experimental/kernels/cpu/interface/quantized_matmul.h
 create mode 100644 torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
index 05dbf13aac..ff4f915b2d 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
@@ -70,6 +70,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
     false,
     false> {
   static void Run(int m, int k, int n, int stride = 1) {
+    // TODO: make use of stride for this kernel
     auto test_case =
         torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::
             generate(m, k, n, a_has_zeros, a_has_zeros, false, false);
diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h
new file mode 100644
index 0000000000..3b070eb2b3
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h
@@ -0,0 +1,133 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cstdint>
+
+namespace torchao::kernels::cpu::fallback::quantized_matmul {
+namespace channelwise_8bit_a_channelwise_8bit_b::internal {
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_transposed>
+struct KernelImpl {
+  static void run(
+      int m,
+      int n,
+      int k,
+      const void* lhs,
+      int lhs_stride_m,
+      const void* rhs,
+      int rhs_stride_n,
+      float* output,
+      int out_stride_m,
+      const int8_t* lhs_zero_points,
+      const int8_t* rhs_zero_points,
+      const float* lhs_scales,
+      const float* rhs_scales,
+      const int lhs_qparams_stride,
+      const int rhs_qparams_stride);
+};
+
+template <bool b_transposed>
+struct KernelImpl<true, true, false, b_transposed> {
+  static void run(
+      int m,
+      int n,
+      int k,
+      const void* lhs,
+      int lhs_stride_m,
+      const void* rhs,
+      int rhs_stride_n,
+      float* output,
+      int out_stride_m,
+      const int8_t* lhs_zero_points,
+      const int8_t* rhs_zero_points,
+      const float* lhs_scales,
+      const float* rhs_scales,
+      const int lhs_qparams_stride,
+      const int rhs_qparams_stride) {
+    const int8_t* lhs_qvals = static_cast<const int8_t*>(lhs);
+    const int8_t* rhs_qvals = static_cast<const int8_t*>(rhs);
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        float res = 0.0;
+        for (int k_idx = 0; k_idx < k; k_idx++) {
+          int lhs_idx = m_idx * lhs_stride_m + k_idx;
+          int rhs_idx = k_idx * rhs_stride_n + n_idx;
+          if (b_transposed) {
+            rhs_idx = n_idx * rhs_stride_n + k_idx;
+          }
+
+          float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] *
+              (static_cast<float>(lhs_qvals[lhs_idx]) -
+               static_cast<float>(
+                   lhs_zero_points[m_idx * lhs_qparams_stride]));
+
+          float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] *
+              (static_cast<float>(rhs_qvals[rhs_idx]) -
+               static_cast<float>(
+                   rhs_zero_points[n_idx * rhs_qparams_stride]));
+
+          res += lhs_dequant * rhs_dequant;
+        }
+        output[m_idx * out_stride_m + n_idx] = res;
+      }
+    }
+  }
+};
+
+} // namespace
+  // channelwise_8bit_a_channelwise_8bit_b::internal
+} // namespace torchao::kernels::cpu::fallback::quantized_matmul
+
+// TODO: Remove all ::kernels. No need for extra namespace.
+namespace torchao::kernels::cpu::fallback::quantized_matmul {
+namespace channelwise_8bit_a_channelwise_8bit_b {
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_transposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const void* lhs,
+    int lhs_stride_m,
+    const void* rhs,
+    int rhs_stride_n,
+    float* output,
+    int out_stride_m,
+    const int8_t* lhs_zero_points,
+    const int8_t* rhs_zero_points,
+    const float* lhs_scales,
+    const float* rhs_scales,
+    const int lhs_qparams_stride,
+    const int rhs_qparams_stride) {
+  channelwise_8bit_a_channelwise_8bit_b::internal::
+      KernelImpl<a_has_zeros, b_has_zeros, a_transposed, b_transposed>::run(
+          m,
+          n,
+          k,
+          lhs,
+          lhs_stride_m,
+          rhs,
+          rhs_stride_n,
+          output,
+          out_stride_m,
+          lhs_zero_points,
+          rhs_zero_points,
+          lhs_scales,
+          rhs_scales,
+          lhs_qparams_stride,
+          rhs_qparams_stride);
+}
+} // namespace channelwise_8bit_a_channelwise_8bit_b
+} // namespace torchao::kernels::cpu::fallback::quantized_matmul
diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
new file mode 100644
index 0000000000..01a4c704c5
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
@@ -0,0 +1,88 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cassert>
+
+#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>
+#endif // defined(__aarch64__) || defined(__ARM_NEON)
+
+namespace torchao::kernels::cpu::quantized_matmul {
+
+/*
+a_stride_m: stride of a in memory, i.e. the number of elements to skip to get
+from one row of a to the next.
+b_stride_n: stride of b in memory, i.e. the number of elements to skip to get
+from one row of b to the next. If b is transposed (n x k), this is typically k;
+if b is not transposed (k x n), this is typically n.
+
+The selection function also returns, via a_stride_m and b_stride_n, the strides
+that should be passed to the returned kernel.
+
+We will need to think of a better way to find the right ukernel, perhaps via a
+ukernel config + registry.
+*/
+using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)(
+    int,
+    int,
+    int,
+    const void*,
+    int,
+    const void*,
+    int,
+    float*,
+    int,
+    const int8_t*,
+    const int8_t*,
+    const float*,
+    const float*,
+    const int,
+    const int);
+
+int8_a_int8_b_channelwise_fp32_c_qmatmul_type
+get_int8_a_int8_b_channelwise_qmatmul(
+    int m,
+    int n,
+    int k,
+    bool a_transposed,
+    bool b_transposed,
+    int& a_stride_m,
+    int& b_stride_n);
+
+int8_a_int8_b_channelwise_fp32_c_qmatmul_type
+get_int8_a_int8_b_channelwise_qmatmul(
+    int m,
+    int n,
+    int k,
+    bool a_transposed,
+    bool b_transposed,
+    int& a_stride_m,
+    int& b_stride_n) {
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  if (!a_transposed && b_transposed && n >= 8) {
+    a_stride_m = k;
+    b_stride_n = k;
+    return aarch64::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
+            kernel<true, true, false, true>;
+  }
+#endif // defined(__aarch64__) || defined(__ARM_NEON)
+  assert(!a_transposed);
+  if (b_transposed) {
+    a_stride_m = k;
+    b_stride_n = k;
+    return torchao::kernels::cpu::fallback::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, true>;
+  } else {
+    // Non-transposed b (k x n): row strides for contiguous layouts.
+    a_stride_m = k;
+    b_stride_n = n;
+    return torchao::kernels::cpu::fallback::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b::
+            kernel<true, true, false, false>;
+  }
+}
+} // namespace torchao::kernels::cpu::quantized_matmul
diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
new file mode 100644
index 0000000000..3629f0960b
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
@@ -0,0 +1,448 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <algorithm>
+#include <cassert>
+#include <cfenv>
+#include <cmath>
+#include <functional>
+#include <random>
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>
+
+float kTol = 0.0001;
+
+// Unfortunately this had to be copied over because the code in test_utils.h
+// depends on quantization kernels which are only buildable for ARM.
+// I would like the testing code in this folder to be independent of the arch.
+namespace {
+void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric) {
+  if (is_symmetric) {
+    qmin = -(1 << (nbit - 1)) + 1;
+    qmax = -qmin;
+  } else {
+    qmin = -(1 << (nbit - 1));
+    qmax = (1 << (nbit - 1)) - 1;
+  }
+}
+
+void get_scale_and_zero(
+    float& scale,
+    int& zero,
+    float vmin,
+    float vmax,
+    int qmin,
+    int qmax) {
+  assert(qmin < qmax);
+  assert(vmin < vmax);
+  scale = (vmax - vmin) / (qmax - qmin);
+  zero = qmin - std::round(vmin / scale);
+}
+
+inline std::vector<float>
+get_random_vector(int size, float min = -1.0, float max = 1.0) {
+  assert(min < max);
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto dist = std::bind(std::uniform_real_distribution<float>(min, max), rng);
+  std::vector<float> res(size);
+  std::generate(res.begin(), res.end(), std::ref(dist));
+  return res;
+}
+
+void quantize(
+    // Output
+    int8_t* qvals,
+    // Inputs
+    const float* vals,
+    int size,
+    float scale,
+    int8_t zero,
+    int8_t qmin,
+    int8_t qmax) {
+  float invScale = 1.0 / (scale + 1e-16);
+  int i = 0;
+  auto curr_rounding_mode = fegetround();
+  fesetround(FE_TONEAREST);
+  for (; i < size; ++i) {
+    // Quantize elements using scalar code
+    float val = vals[i];
+    float qval_f32 = zero + val * invScale;
+    int32_t qval_s32 = static_cast<int32_t>(std::nearbyint(qval_f32));
+
+    // Clip to qmin and qmax
+    qval_s32 = std::max(
+        static_cast<int32_t>(qmin),
+        std::min(qval_s32, static_cast<int32_t>(qmax)));
+
+    // Store the quantized value
+    qvals[i] = static_cast<int8_t>(qval_s32);
+  }
+  fesetround(int(curr_rounding_mode));
+}
+
+auto generate_per_token_quantized_tensor(
+    int m,
+    int n,
+    bool transposed = false) {
+  auto activations = get_random_vector(m * n, -1.0, 1.0);
+  auto activation_qvals = std::vector<int8_t>(m * n, 0);
+  auto activation_scales = std::vector<float>(m, 0);
+  auto activation_zeros = std::vector<int8_t>(m, 0);
+
+  // Quantize activations with 8-bit asymmetric
+  // TODO: replace with generic function that does not use aarch64
+  // quantize method after we combine with torchao
+  int qmin, qmax, zero;
+  float vmin, vmax, scale;
+  get_qvals_range(qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false);
+  for (int m_idx = 0; m_idx < m; m_idx++) {
+    auto minmax = std::minmax_element(
+        activations.data() + m_idx * n, activations.data() + (m_idx + 1) * n);
+    vmin = *minmax.first;
+    vmax = *minmax.second;
+    get_scale_and_zero(scale, zero, vmin, vmax, qmin, qmax);
+    activation_scales[m_idx] = scale;
+    activation_zeros[m_idx] = zero;
+    quantize(
+        /*qvals=*/activation_qvals.data() + m_idx * n,
+        /*vals=*/activations.data() + m_idx * n,
+        /*size=*/n,
+        scale,
+        zero,
+        qmin,
+        qmax);
+  }
+
+  if (transposed) {
+    auto activations_t = std::vector<float>(m * n, 0);
+    auto activation_qvals_t = std::vector<int8_t>(m * n, 0);
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        int activation_idx = m_idx * n + n_idx;
+        int transposed_idx = n_idx * m + m_idx;
+        activations_t[transposed_idx] = activations[activation_idx];
+        activation_qvals_t[transposed_idx] = activation_qvals[activation_idx];
+      }
+    }
+    activations = activations_t;
+    activation_qvals = activation_qvals_t;
+  }
+
+  return std::make_tuple(
+      activations, activation_qvals, activation_scales, activation_zeros);
+}
+
+struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case {
+  int m;
+  int k;
+  int n;
+  int stride;
+
+  bool lhs_has_zeros;
+  bool rhs_has_zeros;
+  bool lhs_is_transposed;
+  bool rhs_is_transposed;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> lhs;
+  std::vector<int8_t> lhs_qvals;
+  std::vector<float> lhs_scales;
+  std::vector<int8_t> lhs_zeros;
+
+  std::vector<float> rhs;
+  std::vector<int8_t> rhs_qvals;
+  std::vector<float> rhs_scales;
+  std::vector<int8_t> rhs_zeros;
+
+  channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case(
+      int m_,
+      int k_,
+      int n_,
+      int stride_,
+      bool lhs_has_zeros_,
+      bool rhs_has_zeros_,
+      bool lhs_is_transposed_,
+      bool rhs_is_transposed_,
+      std::vector<float> expected_output_,
+      std::vector<float> lhs_,
+      std::vector<int8_t> lhs_qvals_,
+      std::vector<float> lhs_scales_,
+      std::vector<int8_t> lhs_zeros_,
+      std::vector<float> rhs_,
+      std::vector<int8_t> rhs_qvals_,
+      std::vector<float> rhs_scales_,
+      std::vector<int8_t> rhs_zeros_)
+      : m(m_),
+        k(k_),
+        n(n_),
+        stride(stride_),
+        lhs_has_zeros(lhs_has_zeros_),
+        rhs_has_zeros(rhs_has_zeros_),
+        lhs_is_transposed(lhs_is_transposed_),
+        rhs_is_transposed(rhs_is_transposed_),
+        expected_output(expected_output_),
+        lhs(lhs_),
+        lhs_qvals(lhs_qvals_),
+        lhs_scales(lhs_scales_),
+        lhs_zeros(lhs_zeros_),
+        rhs(rhs_),
+        rhs_qvals(rhs_qvals_),
+        rhs_scales(rhs_scales_),
+        rhs_zeros(rhs_zeros_) {
+    assert(expected_output.size() == m * n);
+    assert(lhs.size() == m * stride * k);
+    assert(lhs_qvals.size() == m * stride * k);
+    assert(lhs_scales.size() == m * stride);
+    assert(lhs_zeros.size() == m * stride);
+    assert(rhs.size() == n * stride * k);
+    assert(rhs_qvals.size() == n * stride * k);
+    assert(rhs_scales.size() == n * stride);
+    assert(rhs_zeros.size() == n * stride);
+  }
+
+  static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate(
+      int m,
+      int k,
+      int n,
+      bool lhs_has_zeros,
+      bool rhs_has_zeros,
+      bool lhs_is_transposed,
+      // rhs_is_transposed means the generated b matrix is n x k instead of k x n
+      bool rhs_is_transposed,
+      int stride = 1) {
+    assert(!lhs_is_transposed);
+    assert(lhs_has_zeros);
+    assert(rhs_has_zeros);
+    assert(rhs_is_transposed || stride == 1);
+    // Generate activations
+    auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] =
+        generate_per_token_quantized_tensor(m * stride, k);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed);
+    // The function above produces an n x k matrix; to produce k x n you need
+    // transposed = true. We pass !rhs_is_transposed because when
+    // rhs_is_transposed = true the shape should be n x k instead of k x n.
+
+    // Compute expected output
+    std::vector<float> expected_output(m * n);
+
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        float res = 0.0;
+        for (int k_idx = 0; k_idx < k; k_idx++) {
+          int lhs_idx = m_idx * stride * k + k_idx;
+          int rhs_idx = k_idx * stride * n + n_idx * stride;
+          if (rhs_is_transposed) {
+            rhs_idx = n_idx * stride * k + k_idx;
+          }
+          float lhs_dequant = lhs_scales[m_idx * stride] *
+              (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]);
+
+          float rhs_dequant = rhs_scales[n_idx * stride] *
+              (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]);
+
+          res += lhs_dequant * rhs_dequant;
+        }
+        expected_output[m_idx * n + n_idx] = res;
+      }
+    }
+
+    // Return test case
+    return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case(
+        m,
+        k,
+        n,
+        stride,
+        lhs_has_zeros,
+        rhs_has_zeros,
+        lhs_is_transposed,
+        rhs_is_transposed,
+        expected_output,
+        lhs,
+        lhs_qvals,
+        lhs_scales,
+        lhs_zeros,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
+} // namespace
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_transposed>
+struct test_channelwise_8bit_channelwise_8bit_b {
+  static void Run(int m, int k, int n);
+};
+
+template <bool a_has_zeros, bool b_has_zeros>
+struct test_channelwise_8bit_channelwise_8bit_b<
+    a_has_zeros,
+    b_has_zeros,
+    false,
+    true> {
+  static void Run(int m, int k, int n, int stride = 1) {
+    auto test_case =
+        channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::generate(
+            m, k, n, a_has_zeros, b_has_zeros, false, true, stride);
+
+    int a_stride_m, b_stride_n;
+    auto kernel = torchao::kernels::cpu::quantized_matmul::
+        get_int8_a_int8_b_channelwise_qmatmul(
+            m, n, k, false, true, a_stride_m, b_stride_n);
+    a_stride_m = a_stride_m * stride;
+    b_stride_n = b_stride_n * stride;
+
+    std::vector<float> output(m * n);
+    kernel(
+        m,
+        n,
+        k,
+        test_case.lhs_qvals.data(),
+        a_stride_m /*lhs_stride_m*/,
+        test_case.rhs_qvals.data(),
+        b_stride_n /*rhs_stride_n*/,
+        output.data(),
+        n /*out_stride_m*/,
+        test_case.lhs_zeros.data(),
+        test_case.rhs_zeros.data(),
+        test_case.lhs_scales.data(),
+        test_case.rhs_scales.data(),
+        stride, /*lhs qparams stride*/
+        stride /*rhs qparams stride*/);
+
+    for (int i = 0; i < m * n; i++) {
+      EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+    }
+  }
+};
+
+TEST(test_channelwise_8bit_channelwise_8bit_b, TransposedBWithZeroPoints) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/1, /*k=*/128, /*n=*/16);
+}
+
+TEST(test_channelwise_8bit_channelwise_8bit_b, TransposeBWithZeroPointsLargeM) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/128, /*n=*/16);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizes) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/24);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizes2) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/19);
+}
+
+// Test shapes for which we have to use the fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizesFallback) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/5);
+}
+
+// Test shapes for which we have to use the fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizesFallback2) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/2, /*n=*/1);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposeBWithZeroPointsLargeMStrided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/128, /*n=*/16, /*stride=*/5);
+}
+
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizes2Strided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/19, /*stride=*/16);
+}
+
+// Test shapes for which we have to use the fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizesFallbackStrided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/37, /*n=*/5, /*stride=*/7);
+}
+
+// Test shapes for which we have to use the fallback kernel
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TransposedBWithZeroPointsOddSizesFallback2Strided) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/4, /*k=*/2, /*n=*/1, /*stride=*/32);
+}
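
Reviewer note (not part of the patch): below is a minimal usage sketch of the new selection interface, mirroring the call pattern in test_qmatmul_interface.cpp. The shapes, buffer names, and the example_qmatmul function are hypothetical placeholders; in real use the quantized values, scales, and zero points would come from a per-row (channelwise) quantizer such as the test's helper.

// Illustrative sketch only; assumes contiguous row-major buffers and per-row
// (channelwise) 8-bit quantization params with a qparams stride of 1.
#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>

void example_qmatmul() {
  const int m = 4, k = 37, n = 5; // arbitrary example shape
  // a is m x k; b is n x k (i.e. b is transposed), both int8-quantized with
  // one scale/zero point per row.
  std::vector<int8_t> a_qvals(m * k), b_qvals(n * k);
  std::vector<float> a_scales(m, 1.0f), b_scales(n, 1.0f);
  std::vector<int8_t> a_zeros(m, 0), b_zeros(n, 0);
  // ... fill qvals/scales/zeros from a quantizer ...

  int a_stride_m, b_stride_n;
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_int8_a_int8_b_channelwise_qmatmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/true,
          a_stride_m, b_stride_n); // strides to pass back into the kernel

  std::vector<float> output(m * n);
  kernel(
      m, n, k,
      a_qvals.data(), a_stride_m,
      b_qvals.data(), b_stride_n,
      output.data(), /*out_stride_m=*/n,
      a_zeros.data(), b_zeros.data(),
      a_scales.data(), b_scales.data(),
      /*lhs_qparams_stride=*/1, /*rhs_qparams_stride=*/1);
  // output[i * n + j] now holds sum over k of dequant(a[i][k]) * dequant(b[j][k]).
}

Because n < 8 in this example, the selector returns the fallback kernel even on aarch64, which is exactly the case the OddSizesFallback tests above exercise.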