[SYCL][CUDA][MATRIX] joint_matrix_bmad implementation #5363
Signed-off-by: jack.kirk <jack.kirk@codeplay.com>
@@ -495,14 +562,59 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
} else if constexpr (std::is_same<T1, double>::value) {
} else if constexpr (M == 8 && N == 8 && K == 4) {
Is this change related to the bmad addition?
No, this is a superficial change that I made only for better consistency among the if constexpr statements in this function.
get_layout_id<Layout>());
} else if constexpr (NumRows == 128 && NumCols == 8) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
__bmma_m8n8k128_ld_b_b1(res.data, tileptr, stride,
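For orientation, the shape arithmetic behind the NumRows == 128 && NumCols == 8 case can be checked at compile time. This is a sketch, assuming that the single-bit tiles of the m8n8k128 shape (A: 8 x 128 bits, B: 128 x 8 bits) and the 8 x 8 int32 accumulator are spread evenly over a 32-thread warp in 32-bit registers. The helper names are hypothetical and are not part of the SYCL or NVPTX API.

```cpp
// Hypothetical helpers: number of 32-bit registers each thread of a
// 32-thread warp holds for a tile, assuming an even distribution.
constexpr int kWarpSize = 32;
constexpr int kBitsPerRegister = 32;

// Registers per thread for a single-bit (b1) tile of rows x cols bits.
constexpr int b1_regs_per_thread(int rows, int cols) {
  return (rows * cols) / kBitsPerRegister / kWarpSize;
}

// Registers per thread for an int32 accumulator tile of rows x cols.
constexpr int s32_regs_per_thread(int rows, int cols) {
  return (rows * cols) / kWarpSize;
}

// m8n8k128: A is 8 x 128 bits, B is 128 x 8 bits, C/D is 8 x 8 int32.
static_assert(b1_regs_per_thread(8, 128) == 1, "A: 1024 bits over 32 threads");
static_assert(b1_regs_per_thread(128, 8) == 1, "B: 1024 bits over 32 threads");
static_assert(s32_regs_per_thread(8, 8) == 2, "C: 64 int32 over 32 threads");
```

Under this assumption, each thread loads a single 32-bit register for the B tile, which is why one int32_t pointer per tile is enough.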
What types are supported in bmad? Only double and i32?
NVPTX bmad requires that matrix elements are stored in 32-bit untyped registers. int32_t is used here because, when the NVPTX builtins for these functions were created, they were defined with int32_t register arguments (uint32_t could also be used; the NVPTX backend makes no distinction between them). As far as the user is concerned, I think they should work with uint32_t for the bmad cases, as in intel/llvm-test-suite#760.
double is not supported as a register storage type for bmad in NVPTX, and I did not create a case that lets the user use double with bmad.
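The claim that int32_t and uint32_t are interchangeable as bmad storage rests on the fact that reinterpreting the 32 bits does not change them. A minimal sketch of that round trip follows; the function names are hypothetical, and memcpy is used here as a portable stand-in for the reinterpret_cast<int32_t *> in the implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// View the 32 bits of a uint32_t as int32_t without altering them,
// analogous to passing user uint32_t data to builtins taking int32_t.
inline std::int32_t as_int32_bits(std::uint32_t u) {
  std::int32_t s;
  std::memcpy(&s, &u, sizeof(s));
  return s;
}

// Convert back so that both views can be compared bit for bit.
inline std::uint32_t as_uint32_bits(std::int32_t s) {
  std::uint32_t u;
  std::memcpy(&u, &s, sizeof(u));
  return u;
}
```

Because the round trip is lossless, the packed single-bit elements seen by the hardware are identical whichever 32-bit integer type the user chooses.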
Hi @dkhaldi, if it is preferred for reviewing purposes, I could add the temporary/initial fp19 implementation that uses uint32_t directly to this PR. Hopefully the uint32_t fp19 implementation should be a bit more straightforward to review than the bmad cases, since in the end we realized that we can implement the fp19 cases in a way that is completely compliant with the existing matrix extension, whereas the bmad cases require a different interface. Otherwise it is fine to put them up one at a time; I just thought reviewing them together might be easier. Thanks
I think separate PRs are better.
OK
… dimension matrix elements divided by 32. The stride argument in the joint_matrix_load function now refers to the number of registers to stride rather than the number of matrix elements. This leads to a cleaner example because all factors of 32 can be removed.
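The effect of counting the stride in 32-bit registers rather than single-bit elements can be sketched as follows. The function names are hypothetical, and the sketch assumes a row-major packed array in which each uint32_t holds 32 consecutive single-bit elements of a row.

```cpp
#include <cassert>
#include <cstddef>

constexpr std::size_t kBitsPerWord = 32;

// Stride of a row-major packed matrix, in uint32_t words, given its width
// in single-bit elements. The width is expected to be a multiple of 32.
constexpr std::size_t stride_in_words(std::size_t cols_in_bits) {
  return cols_in_bits / kBitsPerWord;
}

// Index of the uint32_t word holding single-bit element (row, col) when the
// stride is counted in words, as the joint_matrix_load stride now is.
inline std::size_t word_index(std::size_t row, std::size_t col_in_bits,
                              std::size_t stride_words) {
  assert(col_in_bits / kBitsPerWord < stride_words);
  return row * stride_words + col_in_bits / kBitsPerWord;
}
```

With 128 single-bit columns the stride becomes 4, so the user code no longer needs to divide by 32 anywhere.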
// number of cols of b.
constexpr int N = 8; // number of cols of accumulator,
// number of rows of a.
constexpr int K = 128; // number of cols of a/number of rows of b.
You missed making the change here.
K should be 4 here.
Can you please add a comment where you are making these changes?
Basically, it should say that the underlying intrinsics expect K to equal the total number of bits, not the number of elements.
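A comment of the kind requested could sit next to a compile-time conversion such as the sketch below. The names kIntrinsicKBits and user_k_to_bits are hypothetical; the only facts taken from this thread are that the NVPTX builtins are named for the m8n8k128 shape, with K counted in bits, and that the user-level K counts uint32_t array elements.

```cpp
// The NVPTX bmma builtins (e.g. __bmma_m8n8k128_ld_b_b1) use a shape whose K
// counts single-bit elements. The user-facing K counts the uint32_t array
// elements that pack those bits, so the two differ by a factor of 32.
constexpr int kBitsPerUint32 = 32;
constexpr int kIntrinsicKBits = 128;  // the "k128" in m8n8k128

constexpr int user_k_to_bits(int user_k) { return user_k * kBitsPerUint32; }

// The user-level K = 4 selects the m8n8k128 intrinsics.
static_assert(user_k_to_bits(4) == kIntrinsicKBits,
              "K = 4 packed uint32_t elements == 128 single-bit elements");
```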
Thanks, I forgot that. It is fixed now, and I also updated the test so that it works with the legacy pass manager.
I've added a more detailed comment describing the Bitwise Dot Product and how it dictates the relationship between the number of array elements used for the A/B arrays and the number of single-bit matrix elements that the A/B arrays represent. I've also updated the test in intel/llvm-test-suite and the tensor cores matrix extension PR #4695 accordingly.
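The bitwise dot product can be sketched as a scalar host reference: for each (row, col) of the accumulator, a packed row of A is combined word by word with a packed column of B using XOR or AND, and the population count of the result is added to the accumulator. This is an illustrative sketch, not the extension's API; the names, the packed layouts, and passing the operation as a parameter are assumptions, loosely modeled on the XOR/AND variants of the PTX single-bit mma operations.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

enum class BitOp { Xor, And };

// Count set bits; written out to avoid depending on C++20 <bit>.
inline int popcount32(std::uint32_t x) {
  int n = 0;
  for (; x != 0; x &= x - 1) ++n;  // clear the lowest set bit each iteration
  return n;
}

// Reference bitwise multiply-add: C[i][j] += sum_k popc(A[i][k] op B[k][j]).
// A_packed: M x K words, row-major (row i packs K * 32 single-bit elements).
// B_packed: N x K words, column j of B stored contiguously, so that row i of
// A and column j of B line up word by word.
// C: M x N int32, row-major.
inline void bmad_reference(const std::vector<std::uint32_t>& A_packed,
                           const std::vector<std::uint32_t>& B_packed,
                           std::vector<std::int32_t>& C, int M, int N, int K,
                           BitOp op) {
  assert(A_packed.size() == static_cast<std::size_t>(M) * K);
  assert(B_packed.size() == static_cast<std::size_t>(N) * K);
  assert(C.size() == static_cast<std::size_t>(M) * N);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      std::int32_t acc = C[i * N + j];
      for (int k = 0; k < K; ++k) {
        std::uint32_t a = A_packed[i * K + k];
        std::uint32_t b = B_packed[j * K + k];
        acc += popcount32(op == BitOp::Xor ? (a ^ b) : (a & b));
      }
      C[i * N + j] = acc;
    }
  }
}
```

Note that K here counts uint32_t words, so K = 4 covers 128 single-bit products per accumulator element.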
// number of cols of b.
constexpr int N = 8; // number of cols of accumulator,
// number of rows of a.
constexpr int K = 4; // number of cols of a/number of rows of b divided by 32 |
number of cols of a/number of rows of b divided by 32
should be:
number of bits in cols of a/number of bits in rows of b divided by 32.
If this is true, do we need to add "divided by 32" in the code example?
What I meant before was to add "multiplied by 32" in the implementation code, to explain that this is how we get the number of bits that the intrinsics operate on. But is this needed in the user-level code?
Here K=4 is not the number of columns in a subgroup matrix: you have to multiply it by 32, since K gives a dimension of the arrays A/B, which hold the single-bit matrix elements in uint32_t storage. Each uint32_t stores 32 single-bit matrix elements.
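The relationship between K = 4 and the 128 single-bit elements per row can be sketched with a hypothetical packing helper that turns one row of the single-bit matrix "a" into its packed uint32_t storage. The bit order within each word (element e stored at bit e % 32 of word e / 32) is an assumption made for illustration; the thread does not specify it, and for a bitwise dot product it only matters that A and B use the same order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack one row of single-bit matrix elements (stored as bools) into uint32_t
// words: element e goes to bit (e % 32) of word (e / 32). This bit order is
// an illustrative assumption.
inline std::vector<std::uint32_t> pack_row(const std::vector<bool>& bits) {
  assert(bits.size() % 32 == 0);
  std::vector<std::uint32_t> words(bits.size() / 32, 0u);
  for (std::size_t e = 0; e < bits.size(); ++e) {
    if (bits[e]) words[e / 32] |= (std::uint32_t{1} << (e % 32));
  }
  return words;
}

// Read single-bit element e back out of the packed representation.
inline bool get_bit(const std::vector<std::uint32_t>& words, std::size_t e) {
  return (words[e / 32] >> (e % 32)) & 1u;
}
```

A row of 128 single-bit elements packs into exactly 4 words, which is the K = 4 used in the example.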
I tried to describe the purpose of these bitwise matrix multiplications here without going into too much detail, and I added references with full details on the origins of the single-bit models and how they use bitwise matrix multiplications. I could not find references to such "bitwise matrix multiplications" being used outside of these models (although this does not mean that such uses don't or won't exist), but I think this functionality was introduced specifically with these use cases in mind.
It is important for such users to understand that joint_matrix_bmad treats each bit as an element of the matrix (the matrix element is "quantized" to a single bit), which is why I set K = 128 in the original implementation. However, as you pointed out, this leads to many factors of 32, because we have to divide by 32 to get the number of uint32_t array elements used to store the matrix.
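To illustrate why treating each bit as a matrix element is useful in the single-bit models referenced above: if each bit encodes +1 (bit set) or -1 (bit clear), an encoding assumed here for illustration, then the ordinary dot product of two such n-element vectors equals n - 2 * popc(a XOR b), because matching bits contribute +1 and differing bits contribute -1. The sketch below checks this identity against a direct computation for one packed word; the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Count set bits; written out to avoid depending on C++20 <bit>.
inline int popcount32(std::uint32_t x) {
  int n = 0;
  for (; x != 0; x &= x - 1) ++n;
  return n;
}

// Dot product of two 32-element {-1, +1} vectors computed element by element,
// where bit b of each word encodes element b (1 -> +1, 0 -> -1).
inline int signed_dot_direct(std::uint32_t a, std::uint32_t b) {
  int dot = 0;
  for (int bit = 0; bit < 32; ++bit) {
    int x = ((a >> bit) & 1u) ? 1 : -1;
    int y = ((b >> bit) & 1u) ? 1 : -1;
    dot += x * y;
  }
  return dot;
}

// Same value from one XOR and one popcount: dot = 32 - 2 * popc(a ^ b).
inline int signed_dot_xor(std::uint32_t a, std::uint32_t b) {
  return 32 - 2 * popcount32(a ^ b);
}
```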
In the current implementation it is nice that these factors are gone, but there should still be proper documentation (see here) describing the relationship between "K" and the actual number of (single-bit) matrix elements. Since this is experimental, I think it is normal to expect that once people start using it there could be feedback suggesting small changes to the interface. I'm not sure whether the interface I originally set up (which led to the factors of 32) or the one you suggested is preferable for users, but I expect that at this experimental stage it can (and most likely will!) be changed in some way in the future anyway.
Should I add back the naming scheme A -> A_Packed, B -> B_Packed that I originally used when switching from K=128 to K=4, to make it clearer that I am calling "a" the matrix and "A_Packed" a packed array representation of the matrix?
Should I then also add a more detailed description to both the tests and the implementation? I did not want to go into too much detail in the tests/implementation because I thought the proper place for such descriptions was the documentation of the extension. This is why I kept things concise here and did not mention details in the implementation.
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
/verify with intel/llvm-test-suite#760
cc @dkhaldi
Implementation corresponding to the matrix extension proposal section "Bitwise Multiply and Add" in #4695
Integration tests here: intel/llvm-test-suite#760