MHA refac: rope without complex operations + query only as input of the forward by vince62s · Pull Request #20 · eole-nlp/eole

MHA refac: rope without complex operations + query only as input of the forward #20

Merged · 1 commit merged into eole-nlp:main on Jun 11, 2024

Conversation

vince62s
Contributor

Regarding the RoPE refactor, posting here the legacy logic for reference:

import torch

dim = 16
base = 10000
maxseqlen = 2048


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
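
# NB: rotate_half is not used by the interleaved comparison below; it matches the
# half-split (GPT-NeoX-style) layout, illustrated in the sketch after the script.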

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
tmax = torch.arange(maxseqlen, device=inv_freq.device)
rope = torch.outer(tmax, inv_freq).float()
# rope is now matrix [maxseqlen, dim/2]

cos_emb = torch.cos(rope)
sin_emb = torch.sin(rope)
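
# Sanity check (not in the original snippet): torch.polar(abs, angle) returns
# abs * (cos(angle) + 1j * sin(angle)), so the complex table built below holds
# exactly the same values as cos_emb / sin_emb.
rope_check = torch.polar(torch.ones_like(rope), rope)
assert torch.allclose(rope_check.real, cos_emb)
assert torch.allclose(rope_check.imag, sin_emb)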


rope1 = torch.polar(torch.ones_like(rope), rope)
rope1 = torch.cat((rope1, rope1), dim=1)
print(rope1.size()) # [2048, 16]
start_pos = 4

query = torch.randn(8, 4, 7, 16)


# OLD CODE: llama-style logic (using complex operations)
query_ = query.float().reshape(8, 4, 7, -1, 2) # [8, 4, 7, 8, 2]
print(query_.size(), query_)
query_ = torch.view_as_complex(query_)
print(query_.size(), query_) # [8, 4, 7, 8] but each is a complex a + bj
print(rope1.size())
rope1 = rope1[start_pos:start_pos+query_.size(2), :rope1.size(1) //2].view(1, 1, query_.size(2), query_.size(3))
print(rope1.size()) # [1, 1, 7, 8]
print(query_ * rope1)
query_out = torch.view_as_real(query_ * rope1).flatten(3)
print(query_out.size(), query_out) # [8, 4, 7, 16]
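
# Why the complex multiply above equals a 2D rotation: with a query pair (a, b)
# viewed as a + bj and a rope entry cos(t) + j*sin(t), the product is
#   (a*cos(t) - b*sin(t)) + j*(a*sin(t) + b*cos(t)),
# which is the same rotation reproduced below with plain cos/sin tensors.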


# same maths but with cos/sin only
query_interleaved = query.reshape(query.shape[0], query.shape[1], query.shape[2], -1, 2)
print(query_interleaved.size())
cos_pos = cos_emb[start_pos:start_pos + query_interleaved.size(2)]
sin_pos = sin_emb[start_pos:start_pos + query_interleaved.size(2)]
print(cos_pos.size(), sin_pos.size())

# Apply rotary embeddings using cosine and sine functions
q_embed_cos = query_interleaved[..., 0] * cos_pos - query_interleaved[..., 1] * sin_pos
q_embed_sin = query_interleaved[..., 0] * sin_pos + query_interleaved[..., 1] * cos_pos

# Combine cosine and sine embeddings
q_embed = torch.stack((q_embed_cos, q_embed_sin), dim=-1)

# Flatten and reshape the output
query_out1 = q_embed.flatten(3)
print(query_out1.size(), query_out1)

print(torch.allclose(query_out, query_out1))
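
Side note, not part of this PR's diff: the rotate_half helper defined at the top goes with the half-split layout (pairs (x[i], x[i + dim//2])) rather than the interleaved one (pairs (x[2i], x[2i+1])). A minimal sketch reusing the tensors above (perm, inv_perm, cos_cat, sin_cat, query_out2 are new names introduced here) shows it gives the same result once the head dimension is permuted between the two layouts:

# rotate_half-based application on the half-split layout (sketch)
perm = torch.cat((torch.arange(0, dim, 2), torch.arange(1, dim, 2)))  # even dims first, then odd dims
inv_perm = torch.argsort(perm)

cos_cat = torch.cat((cos_pos, cos_pos), dim=-1)  # [7, 16]
sin_cat = torch.cat((sin_pos, sin_pos), dim=-1)  # [7, 16]

query_half = query[..., perm]                                   # interleaved -> half-split
q_embed_half = query_half * cos_cat + rotate_half(query_half) * sin_cat
query_out2 = q_embed_half[..., inv_perm]                        # half-split -> interleaved

print(torch.allclose(query_out1, query_out2))  # True

Which layout a given checkpoint expects is a separate question; the sketch only shows that the two formulations apply the same rotation up to that permutation of the head dimension.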

vince62s merged commit acd7b9a into eole-nlp:main on Jun 11, 2024
2 checks passed
vince62s added two commits that referenced this pull request on Jun 12, 2024