Muxas/gqa #101
Conversation
General comment: let's stick to `numpy.testing` rather than `torch.testing`. NumPy provides a richer set of assertions. Its core idea is a drop-in replacement for the built-in `assert` statement with pretty printing and numerical precision controls, e.g.

```python
from numpy.testing import assert_allclose, assert_equal, ...

assert_equal(lhs, rhs)
```

Mixing different APIs could be confusing for a reader. Also, interop in PyTorch is quite lame.
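For context, a minimal sketch of the kind of check this suggests, comparing a result against a PyTorch reference through NumPy (all names here are illustrative, not taken from this PR):

```python
import numpy as np
import torch
from numpy.testing import assert_allclose

reference = torch.randn(4, 8)          # stand-in for a torch reference output
candidate = reference.numpy().copy()   # stand-in for the output under test

# rtol/atol give explicit control over the numerical precision of the check
assert_allclose(candidate, reference.numpy(), rtol=1e-6, atol=1e-7)
```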
```python
n_head=16,
n_head_tile=8,
n_head_kv=4,
dtype=np.dtypes.Float32DType(),
```
Just `np.float32`.
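A quick sanity check of that equivalence (assuming NumPy >= 1.25, where the `np.dtypes` namespace exists):

```python
import numpy as np

# Both spellings denote the same dtype; np.float32 is simply shorter
assert np.dtype(np.float32) == np.dtypes.Float32DType()
```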
I just repeated how it is done in the batch norm test. We have to either use Torch dtypes or rely on our own dtypes. Supporting bf16, fp16, fp8 and quantised formats is out of NumPy's scope.
#102 is meant to resolve this issue in the future.
> Supporting bf16, fp16, fp8 and quantised formats is out of NumPy's scope.

Absolutely not. There are widely adopted custom types (e.g. in pandas). Also, NumPy facilitates extensions with NEP-42. Thanks to the JAX team, we have ml-dtypes, which provides common low-bit floating-point types in a framework-agnostic way.
[1] NEP-42
[2] numpy/numpy-user-dtypes
[3] jax-ml/ml-dtypes
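As a hedged illustration of that framework-agnostic route (assuming the `ml_dtypes` package is installed; it is not part of this PR):

```python
import numpy as np
import ml_dtypes

# ml_dtypes registers bfloat16 and float8 variants as regular NumPy dtypes
x = np.array([0.5, 1.5, 2.5], dtype=ml_dtypes.bfloat16)
y = x.astype(np.float32)   # upcast back to float32 for comparisons
print(x.dtype, y.dtype)    # bfloat16 float32
```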
NEP-42 is still under construction, as the link says, but in the future it will allow us to add our own quantised formats. JAX's ml-dtypes can be used for now, as it covers the bfloat16 and float8 types. We need these only for testing purposes, since interoperation between NNTile and NumPy goes through upcasting NNTile data into `float` for all floats of 32 bits or fewer and into `double` for 64-bit floats.
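A minimal sketch of that testing pattern, assuming `ml_dtypes` provides the low-precision side and a plain NumPy array stands in for data coming out of NNTile:

```python
import numpy as np
import ml_dtypes
from numpy.testing import assert_allclose

# Stand-in for low-precision data produced by an NNTile tensor
low_precision = np.array([0.1, 0.2, 0.3], dtype=ml_dtypes.bfloat16)

# Upcast to float32 on the NumPy side before comparing against a reference
upcast = low_precision.astype(np.float32)
reference = np.array([0.1, 0.2, 0.3], dtype=np.float32)

# bfloat16 keeps only ~3 decimal digits, so the tolerance is loose
assert_allclose(upcast, reference, rtol=1e-2)
```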
> We need these only for testing purposes, since interoperation between NNTile and NumPy goes through upcasting NNTile data into `float` for all floats of 32 bits or fewer and into `double` for 64-bit floats.

Then we do not actually need dtypes. Just define an `upcast()` routine and expose it to Python with limited visibility. That's all.
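A minimal sketch of what such a routine could look like on the Python side (purely hypothetical; the name and the dtype mapping are assumptions, not NNTile's actual API):

```python
import numpy as np

def _upcast(array: np.ndarray) -> np.ndarray:
    """Hypothetical helper: upcast floats of 32 bits or fewer to float32
    and 64-bit floats to float64 for test-time comparisons."""
    if array.dtype.itemsize <= 4:
        return array.astype(np.float32)
    return array.astype(np.float64)
```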
This is a distinct topic; I don't understand why you brought it up now. The original comment was about replacing the internal `numpy.dtypes.*` in favor of `numpy.float32`.
I just made it a string. It looks much simpler now.
```python
self.w_q.value.wont_use()
# Apply bias if needed
if self.in_proj_bias_q is not None:
    # batched add_fiber (head_size, batch=(kv_group_size, n_head_kv))
```
What does the internal tuple `batch=()` mean? What shape should be returned in this case?
This is just a virtual union of several dimensions into a single batch dimension. The parameter `batch_ndim=2` is set on the next line.
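For intuition only, a NumPy sketch of treating the two trailing axes as one virtual batch dimension (NNTile's add_fiber operates on its own tensor type, so this is not its actual implementation):

```python
import numpy as np

head_size, kv_group_size, n_head_kv = 8, 4, 2

# A bias laid out with two trailing "batch" axes
bias = np.random.rand(head_size, kv_group_size, n_head_kv)

# batch_ndim=2 means those two axes are viewed as a single flat batch
flat = bias.reshape(head_size, kv_group_size * n_head_kv)
assert flat.shape == (head_size, 8)
```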
This PR implements LlamaAttention without RotaryEmbedding. It is checked against LlamaAttention from transformers by providing a zeroed `position_ids` argument. @daskol, please check whether the test can be further improved for readability and `pytest` ideology.
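For reference, a rough sketch of the comparison on the transformers side (a hedged assumption: it targets a transformers version, roughly 4.36-4.45, where LlamaAttention.forward still accepts position_ids directly; shapes and config values are illustrative, not the ones used in this PR's test):

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

config = LlamaConfig(hidden_size=128, num_attention_heads=16,
                     num_key_value_heads=4, intermediate_size=256)
attn = LlamaAttention(config, layer_idx=0)

batch, seq_len = 2, 10
hidden = torch.randn(batch, seq_len, config.hidden_size)

# Zeroed position_ids make the rotary embedding a no-op (cos=1, sin=0),
# which allows comparing against an attention implemented without RoPE
position_ids = torch.zeros(batch, seq_len, dtype=torch.long)

with torch.no_grad():
    reference = attn(hidden, position_ids=position_ids)[0]
```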