MHA refac: rope without complex operations + query only as input of the forward by vince62s · Pull Request #20 · eole-nlp/eole · GitHub

MHA refac: rope without complex operations + query only as input of the forward #20

Merged · 1 commit · Jun 11, 2024
16 changes: 0 additions & 16 deletions eole/decoders/transformer_base.py
@@ -125,22 +125,6 @@ def _compute_dec_mask(self, tgt_pad_mask, future):
dec_mask = tgt_pad_mask
return dec_mask

def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False):
if self.self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
return self.self_attn(
norm_layer_in,
norm_layer_in,
norm_layer_in,
mask=dec_mask,
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
)
elif self.self_attn_type == "average":
return self.self_attn(norm_layer_in, mask=dec_mask, step=step)
else:
raise ValueError(f"self attention {type(self.self_attn)} not supported")


class TransformerDecoderBase(DecoderBase):
def __init__(
21 changes: 14 additions & 7 deletions eole/decoders/transformer_decoder.py
@@ -93,17 +93,21 @@ def _forward(

norm_layer_in = self.input_layernorm(layer_in)

self_attn, _ = self._forward_self_attn(
norm_layer_in, dec_mask, step, return_attn=return_attn
self_attn, _ = self.self_attn(
norm_layer_in,
mask=dec_mask,
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
)
if self.dropout_p > 0:
self_attn = self.dropout(self_attn)

if self.parallel_residual:
ctx_attn, attns = self.context_attn(
enc_out,
enc_out,
norm_layer_in,
key=enc_out,
value=enc_out,
mask=src_pad_mask,
return_attn=return_attn,
)
@@ -115,7 +119,11 @@
else:
norm_query = self.precontext_layernorm(self_attn + layer_in)
ctx_attn, attns = self.context_attn(
enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn
norm_query,
key=enc_out,
value=enc_out,
mask=src_pad_mask,
return_attn=return_attn,
)
if self.dropout_p > 0:
ctx_attn = self.dropout(ctx_attn)
@@ -234,7 +242,6 @@ def _init_cache(self, device):
"values": torch.tensor([], device=device),
},
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(device)
if hasattr(layer.self_attn, "cos") and layer.self_attn.cos is not None:
layer.self_attn.cos = layer.self_attn.cos.to(device)
layer.self_attn.sin = layer.self_attn.sin.to(device)
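The decoder diff above shows the new calling convention in both places: self-attention now receives only the normalized query, while context attention passes the encoder output explicitly through the key/value keywords. A minimal, runnable sketch of that convention, using a hypothetical stand-in module (TinyAttn is illustrative only, not the eole MultiHeadedAttention):

from typing import Optional

import torch
from torch import Tensor, nn


class TinyAttn(nn.Module):
    """Hypothetical stand-in with the new query-first signature."""

    def forward(
        self,
        query: Tensor,
        key: Optional[Tensor] = None,
        value: Optional[Tensor] = None,
    ):
        # Query-only call => self-attention: key/value fall back to the query,
        # matching the `key if key is not None else query` logic in the MHA diff.
        key = key if key is not None else query
        value = value if value is not None else query
        scores = torch.softmax(
            query @ key.transpose(-2, -1) / key.size(-1) ** 0.5, dim=-1
        )
        return scores @ value, scores


attn = TinyAttn()
x = torch.randn(2, 5, 32)         # (batch, tgt_len, dim)
enc_out = torch.randn(2, 7, 32)   # (batch, src_len, dim)
self_out, _ = attn(x)                               # self-attention style call
ctx_out, _ = attn(x, key=enc_out, value=enc_out)    # context-attention style call
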
12 changes: 8 additions & 4 deletions eole/decoders/transformer_lm_decoder.py
@@ -53,9 +53,14 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals

norm_layer_in = self.input_layernorm(layer_in)

attn_output, attns = self._forward_self_attn(
norm_layer_in, dec_mask, step, return_attn=return_attn
attn_output, attns = self.self_attn(
norm_layer_in,
mask=dec_mask,
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
)

if self.dropout_p > 0:
attn_output = self.dropout(attn_output)
if self.parallel_residual:
@@ -156,7 +161,6 @@ def _init_cache(self, device, mask):
"key_pad_mask": mask,
},
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(device)
if hasattr(layer.self_attn, "cos") and layer.self_attn.cos is not None:
layer.self_attn.cos = layer.self_attn.cos.to(device)
layer.self_attn.sin = layer.self_attn.sin.to(device)
4 changes: 1 addition & 3 deletions eole/encoders/transformer.py
@@ -57,9 +57,7 @@ def forward(self, layer_in, mask):
* layer_out ``(batch_size, src_len, model_dim)``
"""
norm_layer_in = self.input_layernorm(layer_in)
context, _ = self.self_attn(
norm_layer_in, norm_layer_in, norm_layer_in, mask=mask
)
context, _ = self.self_attn(norm_layer_in, mask=mask)
if self.dropout_p > 0:
context = self.dropout(context)
if self.parallel_residual:
106 changes: 55 additions & 51 deletions eole/modules/multi_headed_attn.py
@@ -19,18 +19,13 @@
# are both < 2048 tokens.


def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
tmax = torch.arange(maxseqlen, device=inv_freq.device)
def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=torch.device("cpu")):
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
tmax = torch.arange(maxseqlen, device=device)
rope = torch.outer(tmax, inv_freq).float()
# rope is now matrix [maxseqlen, dim/2]
rope = torch.polar(torch.ones_like(rope), rope)
rope = torch.cat((rope, rope), dim=1)
if device is not None:
rope = rope.to(device)
cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
return rope, cos, sin
cos_emb = torch.cos(rope).half()
sin_emb = torch.sin(rope).half()
return cos_emb, sin_emb
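A quick standalone check (not repository code) that nothing changes numerically: the new cos/sin tables are exactly the real and imaginary parts of the complex tensor the removed torch.polar construction produced.

import torch

dim, maxseqlen, base = 64, 32, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(maxseqlen), inv_freq).float()   # [maxseqlen, dim/2]

# Removed construction: complex rope via torch.polar, duplicated, then split.
rope = torch.cat([torch.polar(torch.ones_like(angles), angles)] * 2, dim=1)
old_cos = rope[:, : rope.size(1) // 2].real
old_sin = rope[:, : rope.size(1) // 2].imag

# New construction: plain cos/sin of the same angle matrix
# (kept in float32 here; the function above casts to half precision).
assert torch.allclose(old_cos, torch.cos(angles))
assert torch.allclose(old_sin, torch.sin(angles))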


def rotate_half(x):
@@ -40,36 +35,36 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(query, key, rope, interleave):
def apply_rotary_emb(query, key, cos_pos, sin_pos, interleave):
if interleave:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query_ = query.float().reshape(*query.shape[:-1], -1, 2)
query_ = torch.view_as_complex(query_)
key_ = key.float().reshape(*key.shape[:-1], -1, 2)
key_ = torch.view_as_complex(key_)
rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
query_out = torch.view_as_real(query_ * rope).flatten(3)
key_out = torch.view_as_real(key_ * rope).flatten(3)
return query_out.transpose(1, 2).type_as(query), key_out.transpose(
1, 2
).type_as(key)
query_ = query.float().reshape(
*query.shape[:-1], -1, 2
) # [B, H, L, dimperhead//2, 2]]
query_cos = query_[..., 0] * cos_pos - query_[..., 1] * sin_pos
query_sin = query_[..., 0] * sin_pos + query_[..., 1] * cos_pos
query_ = torch.stack((query_cos, query_sin), dim=-1).flatten(3)
key_ = key.float().reshape(
*key.shape[:-1], -1, 2
) # [B, H, L, dimperhead//2, 2]]
key_cos = key_[..., 0] * cos_pos - key_[..., 1] * sin_pos
key_sin = key_[..., 0] * sin_pos + key_[..., 1] * cos_pos
key_ = torch.stack((key_cos, key_sin), dim=-1).flatten(3)
return query_.type_as(query), key_.type_as(key)
else:
cos, sin = rope.real, rope.imag
rotary_dim = cos.size(1)
rotary_dim = cos_pos.size(1)
head_dim = query.size(3)
if rotary_dim < head_dim:
q_embed = (query[:, :, :, :rotary_dim] * cos) + (
rotate_half(query[:, :, :, :rotary_dim]) * sin
q_embed = (query[:, :, :, :rotary_dim] * cos_pos) + (
rotate_half(query[:, :, :, :rotary_dim]) * sin_pos
)
k_embed = (key[:, :, :, :rotary_dim] * cos) + (
rotate_half(key[:, :, :, :rotary_dim]) * sin
k_embed = (key[:, :, :, :rotary_dim] * cos_pos) + (
rotate_half(key[:, :, :, :rotary_dim]) * sin_pos
)
q_embed = torch.cat([q_embed, query[:, :, :, rotary_dim:]], dim=-1)
k_embed = torch.cat([k_embed, key[:, :, :, rotary_dim:]], dim=-1)
else:
q_embed = (query * cos) + (rotate_half(query) * sin)
k_embed = (key * cos) + (rotate_half(key) * sin)
q_embed = (query * cos_pos) + (rotate_half(query) * sin_pos)
k_embed = (key * cos_pos) + (rotate_half(key) * sin_pos)
return q_embed.type_as(query), k_embed.type_as(key)
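The interleaved branch above just spells out the complex product from the removed code in real arithmetic: (a + ib)(cos t + i sin t) = (a cos t - b sin t) + i(a sin t + b cos t), which is the query_cos / query_sin pair. A standalone sanity check of that equivalence (a sketch assuming float32 inputs and the interleave=True path only; not repository code):

import torch

B, H, L, D = 2, 4, 8, 16
q = torch.randn(B, H, L, D)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(torch.arange(L), inv_freq).float()        # [L, D/2]
cos_pos, sin_pos = torch.cos(angles), torch.sin(angles)

# New formulation: rotate (even, odd) pairs with real cos/sin arithmetic.
q_pairs = q.reshape(B, H, L, D // 2, 2)
new = torch.stack(
    (
        q_pairs[..., 0] * cos_pos - q_pairs[..., 1] * sin_pos,
        q_pairs[..., 0] * sin_pos + q_pairs[..., 1] * cos_pos,
    ),
    dim=-1,
).flatten(3)

# Removed formulation: complex multiply with a polar rope tensor.
rope = torch.polar(torch.ones_like(angles), angles).view(1, L, 1, D // 2)
q_c = torch.view_as_complex(q.transpose(1, 2).reshape(B, L, H, D // 2, 2))
old = torch.view_as_real(q_c * rope).flatten(3).transpose(1, 2)

assert torch.allclose(new, old, atol=1e-6)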


@@ -350,7 +345,7 @@ def __init__(
self.rotary_dim = self.dim_per_head
else:
self.rotary_dim = model_config.rotary_dim
self.rope, self.cos, self.sin = rotaryembeddings(
self.cos, self.sin = rotaryembeddings(
self.rotary_dim, base=model_config.rotary_theta
)
self.rotary_interleave = model_config.rotary_interleave
@@ -390,9 +385,9 @@ def update_dropout(self, dropout: float) -> None:

def forward(
self,
key: Tensor,
value: Tensor,
query: Tensor,
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
sliding_window: Optional[int] = 0,
step: Optional[int] = 0,
@@ -402,15 +397,17 @@ def forward(
Compute the context vector and the attention vectors.

Args:
query (Tensor): set of `query_len`
query vectors ``(batch, query_len, dim)``
key (Tensor): set of `key_len`
key vectors ``(batch, key_len, dim)``
value (Tensor): set of `key_len`
value vectors ``(batch, key_len, dim)``
query (Tensor): set of `query_len`
query vectors ``(batch, query_len, dim)``
mask: binary mask 1/0 indicating which keys have
zero / non-zero attention ``(batch, query_len, key_len)``
sliding_window (int): sliding context for Mistral based models
step (int): decoding step (used for Rotary embedding)
return_attn (bool): whether to return the attention tensor
Returns:
(Tensor, Tensor):

@@ -444,17 +441,20 @@ def forward(
or query.dtype != torch.float16
):
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen + start_pos > self.rope.size(0):
if seqlen + start_pos > self.cos.size(0):
# Resize rotary embeddings.
self.rope, self.cos, self.sin = rotaryembeddings(
self.cos, self.sin = rotaryembeddings(
self.rotary_dim,
maxseqlen=(seqlen + start_pos + 2048),
base=self.rotary_theta,
device=self.rope.device,
device=self.cos.device,
)
rope = self.rope[start_pos : start_pos + seqlen]
cos, sin = (
self.cos[start_pos : start_pos + seqlen],
self.sin[start_pos : start_pos + seqlen],
)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
query, key, cos, sin, interleave=self.rotary_interleave
)

if self.layer_cache[1]["keys"].numel() != 0:
@@ -495,16 +495,16 @@
)
if (
self.max_relative_positions == -1
and start_pos + 32 >= self.rope.size(0)
and start_pos + 32 >= self.cos.size(0)
):
# Resize rotary embeddings.
# We take a margin of 32 tokens as the kv_cache
# is incremented by 32 tokens every 32 tokens.
self.rope, self.cos, self.sin = rotaryembeddings(
self.cos, self.sin = rotaryembeddings(
self.rotary_dim,
maxseqlen=(start_pos + 2048),
base=self.rotary_theta,
device=self.rope.device,
device=self.cos.device,
)

if sliding_window > 0 and key.size(2) > sliding_window:
@@ -560,8 +560,10 @@ def forward(
key_pad_mask = self.layer_cache[1]["key_pad_mask"]
else:
# Retrieve keys and values from linear layers (training mode).
key = self.maybe_ckpt(self.linear_keys, key)
value = self.maybe_ckpt(self.linear_values, value)
key = self.maybe_ckpt(self.linear_keys, key if key is not None else query)
value = self.maybe_ckpt(
self.linear_values, value if value is not None else query
)
query = self.maybe_ckpt(self.linear_query, query)

key = shape(key, self.dim_per_head)
@@ -571,17 +573,19 @@
if self.max_relative_positions == -1: # Rotary Embeddings
start_pos = 0
seqlen = query.size(2)
if seqlen > self.rope.size(0):
if seqlen > self.cos.size(0):
# Resize rotary embeddings.
self.rope, self.cos, self.sin = rotaryembeddings(
self.cos, self.sin = rotaryembeddings(
self.rotary_dim,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
device=query.device,
)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
cos, sin = self.cos[start_pos : start_pos + seqlen].to(
query.device
), self.sin[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
query, key, cos, sin, interleave=self.rotary_interleave
)

b, h, l, d = key.size()
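Finally, a simplified standalone sketch (hypothetical helper name, illustrative constants, half-precision cast omitted) of how the precomputed cos/sin tables are consumed during step-wise decoding, mirroring the resize-then-slice logic in the diff above: when the current position runs past the table, the tables are rebuilt with extra headroom, and the per-step slice is what gets handed to apply_rotary_emb.

import torch


def rope_tables(dim, maxseqlen, base=10000, device=torch.device("cpu")):
    # Hypothetical helper mirroring rotaryembeddings above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    angles = torch.outer(torch.arange(maxseqlen, device=device), inv_freq).float()
    return torch.cos(angles), torch.sin(angles)


dim, seqlen, start_pos = 16, 1, 8              # decoding one token at position 8
cos, sin = rope_tables(dim, maxseqlen=8)       # table only covers positions 0..7
if seqlen + start_pos > cos.size(0):           # too short: rebuild with headroom
    cos, sin = rope_tables(dim, maxseqlen=seqlen + start_pos + 2048)
cos_step = cos[start_pos : start_pos + seqlen]  # [seqlen, dim/2] slices passed to
sin_step = sin[start_pos : start_pos + seqlen]  # apply_rotary_emb for this step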