Muxas/gqa by Muxas · Pull Request #101 · nntile/nntile · GitHub

Muxas/gqa #101

Merged
14 commits merged into main from muxas/gqa on Jul 15, 2024

Conversation

Muxas
Member
@Muxas Muxas commented Jul 14, 2024

This PR implements LlamaAttention without RotaryEmbedding. It is checked against LlamaAttention from transformers by providing a zeroed position_ids argument. @daskol, please check whether the test can be further improved for readability and pytest idioms.
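For context, a rough sketch of how such a reference comparison can look (illustrative only: the exact LlamaAttention constructor and forward signature depend on the transformers version, and the sizes below are made up except for n_head=16 and n_head_kv=4). With all-zero position_ids the rotary embedding rotates by angle zero (cos = 1, sin = 0), i.e. it is effectively an identity, so the transformers layer serves as a reference for attention without rotary.

# Illustrative sketch, not the test as written in this PR
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

config = LlamaConfig(hidden_size=128, num_attention_heads=16,
                     num_key_value_heads=4)
attn = LlamaAttention(config, layer_idx=0).eval()

x = torch.randn(2, 10, config.hidden_size)           # (batch, seq, hidden)
position_ids = torch.zeros(2, 10, dtype=torch.long)  # all zeros: no rotation

with torch.no_grad():
    out, *_ = attn(x, position_ids=position_ids)
print(out.shape)  # reference output to compare NNTile results against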

@Muxas Muxas requested review from amkatrutsa and daskol July 14, 2024 11:59
@Muxas Muxas linked an issue Jul 14, 2024 that may be closed by this pull request
Member
@daskol daskol left a comment

A general comment: let's stick to numpy.testing rather than torch.testing. NumPy provides a richer set of assertions. Its core idea is a drop-in replacement for the built-in assert statement with pretty printing and numerical precision controls, i.e.

from numpy.testing import assert_allclose, assert_equal, ...
assert_equal(lhs, rhs)

Mixing different APIs could be confusing for a reader. Also, interop in PyTorch is quite lame.
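For example, a minimal sketch with made-up arrays standing in for the transformers reference and the NNTile output:

import numpy as np
from numpy.testing import assert_allclose, assert_equal

# Stand-ins for the reference and candidate outputs
reference = np.random.rand(2, 10, 128).astype(np.float32)
candidate = reference + np.float32(1e-7)  # tiny numerical noise

assert_equal(reference.shape, candidate.shape)              # exact check
assert_allclose(candidate, reference, rtol=1e-5, atol=1e-6)  # tolerant check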

n_head=16,
n_head_tile=8,
n_head_kv=4,
dtype=np.dtypes.Float32DType(),
Member

Just np.float32.
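(As a quick sanity check, a one-liner sketch, assuming NumPy 1.25+ for the np.dtypes module:)

import numpy as np

# Both spellings resolve to the same NumPy dtype
assert np.dtypes.Float32DType() == np.dtype(np.float32)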

Member Author

I just repeated how it is done in the batch norm test. We have to either use Torch dtypes or rely on our own dtypes. Supporting bf16, fp16, fp8 and quantised formats is out of NumPy's scope.

Member Author

#102 is meant to resolve this issue in the future.

Member

> Supporting bf16, fp16, fp8 and quantised formats is out of NumPy's scope.

Absolutely not. There are widely adopted custom types (e.g. in pandas). Also, NumPy facilitates extensions with NEP-42. Thanks to the JAX team, we have ml-dtypes, which provides common low-bit floating-point types in a framework-agnostic way (a usage sketch follows the references).

[1] NEP-42
[2] numpy/numpy-user-dtypes
[3] jax-ml/ml-dtypes
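A minimal usage sketch, assuming the ml-dtypes package is installed:

import numpy as np
import ml_dtypes

# bfloat16 behaves like a regular NumPy dtype once ml_dtypes is imported
x = np.array([0.1, 0.5, 1.5], dtype=ml_dtypes.bfloat16)
y = x.astype(np.float32)  # upcast for comparison with a full-precision reference
print(x.dtype, y)         # e.g. bfloat16 [0.10009766 0.5 1.5]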

Member Author

NEP-42 is still under construction, as the link says, but in the future it will allow us to add our own quantised formats. JAX's ml-dtypes can be used for now, as it covers bfloat16 and float8 types. We need these only for testing purposes, as interoperation between NNTile and NumPy goes through upcasting NNTile data into float for all floats of 32 bits or fewer and into double for 64-bit floats.

Member
@daskol daskol Jul 15, 2024

> We need these only for testing purposes, as interoperation between NNTile and NumPy goes through upcasting NNTile data into float for all floats of 32 bits or fewer and into double for 64-bit floats.

Then we do not actually need dtypes. Just define an upcast() routine and expose it to Python with limited visibility. That's all.
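A minimal sketch of that idea; the function name, dtype strings and buffer-based signature below are hypothetical, not the NNTile API:

import numpy as np
import ml_dtypes

def upcast(buffer: bytes, dtype_name: str, shape: tuple) -> np.ndarray:
    """Reinterpret raw low-precision data and upcast it for testing only."""
    # Hypothetical mapping from dtype names to storage dtypes
    storage = {
        "bf16": ml_dtypes.bfloat16,
        "fp16": np.float16,
        "fp32": np.float32,
        "fp64": np.float64,
    }
    src = np.frombuffer(buffer, dtype=storage[dtype_name]).reshape(shape)
    # 32-bit and narrower floats compare in float32, 64-bit ones in float64
    return src.astype(np.float32 if src.itemsize <= 4 else np.float64)

raw = np.arange(4, dtype=np.float32).astype(ml_dtypes.bfloat16).tobytes()
print(upcast(raw, "bf16", (4,)))  # float32 array: [0. 1. 2. 3.]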

Member

This is a distinct topic; I don't understand why you brought it up now.

The original comment was about replacing the internal numpy.dtypes.* in favor of numpy.float32.

Member Author

I just made it a string. It looks much simpler now.

self.w_q.value.wont_use()
# Apply bias if needed
if self.in_proj_bias_q is not None:
    # batched add_fiber (head_size, batch=(kv_group_size, n_head_kv))
Contributor

What does the internal tuple batch=() mean? What shape should be returned in this case?

Member Author

This is just a virtual union of several dimensions into a single batch dimension; the parameter batch_ndim=2 is set on the next line.
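A plain-NumPy illustration (shapes simplified, not the actual NNTile tensor layout): the bias fiber of length head_size is added across the two trailing axes, which batch_ndim=2 treats as a single flattened batch dimension.

import numpy as np

head_size, kv_group_size, n_head_kv = 8, 2, 4
q = np.random.rand(head_size, kv_group_size, n_head_kv)
bias = np.random.rand(head_size)

# batch_ndim=2: the last two axes act as one batch of size kv_group_size * n_head_kv
q_biased = q + bias[:, None, None]
assert q_biased.shape == q.shape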

@Muxas Muxas merged commit ccbb219 into main Jul 15, 2024
2 of 5 checks passed
@Muxas Muxas deleted the muxas/gqa branch July 15, 2024 11:30

Successfully merging this pull request may close these issues.

Operation: Group Query Attention