simplify LayerNorm access as a constant by vince62s · Pull Request #7 · eole-nlp/eole · GitHub

simplify LayerNorm access as a constant #7

Merged
merged 2 commits on Jun 6, 2024
9 changes: 3 additions & 6 deletions eole/config/models.py
@@ -1,10 +1,7 @@
from typing import Dict, Union, Literal, Any, Annotated
from pydantic import Field, field_validator, model_validator # , TypeAdapter

from eole import constants
from eole.modules.transformer_mlp import (
ActivationFunction,
) # might be better defined elsewhere
from eole.constants import PositionEncodingType, ActivationFunction
from eole.config.config import Config


@@ -30,8 +27,8 @@ class EmbeddingsConfig(Config):
description="Use a sin to mark relative words positions. "
"Necessary for non-RNN style models.",
)
position_encoding_type: constants.PositionEncodingType = Field(
default=constants.PositionEncodingType.SinusoidalInterleaved,
position_encoding_type: PositionEncodingType = Field(
default=PositionEncodingType.SinusoidalInterleaved,
description="Type of positional encoding.",
)

23 changes: 23 additions & 0 deletions eole/constants.py
@@ -1,5 +1,8 @@
"""Define constant values used across the project."""
from enum import Enum
import torch
from eole.modules.rmsnorm import RMSNorm
import torch.nn.functional as F


class DefaultTokens(object):
@@ -46,3 +49,23 @@ class ModelTask(str, Enum):
class PositionEncodingType(str, Enum):
SinusoidalInterleaved = "SinusoidalInterleaved"
SinusoidalConcat = "SinusoidalConcat"


class ActivationFunction(str, Enum):
relu = "relu"
gelu = "gelu"
silu = "silu"
gated_gelu = "gated-gelu"
gated_silu = "gated-silu"


ACTIVATION_FUNCTIONS = {
ActivationFunction.relu: F.relu,
ActivationFunction.gelu: F.gelu,
ActivationFunction.silu: F.silu,
ActivationFunction.gated_gelu: F.gelu,
ActivationFunction.gated_silu: F.silu,
}


LayerNorm = {"standard": torch.nn.LayerNorm, "rms": RMSNorm}
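
The dict above is the heart of the PR: every encoder and decoder touched below, which previously repeated an if/elif chain on model_config.layer_norm, now just indexes this constant. A minimal sketch of the resulting call-site pattern, assuming the sizes and the "rms" choice below (illustrative stand-ins, not values from the diff):

import torch

from eole.constants import ACTIVATION_FUNCTIONS, ActivationFunction, LayerNorm

hidden_size, norm_eps = 512, 1e-6   # stand-ins for model_config.hidden_size / norm_eps
layer_norm_type = "rms"             # stand-in for model_config.layer_norm

# One dict lookup replaces the old if/elif/else that picked nn.LayerNorm or RMSNorm.
norm = LayerNorm[layer_norm_type](hidden_size, eps=norm_eps)

# The activation registry moved here from transformer_mlp.py and is used the same way.
act_fn = ACTIVATION_FUNCTIONS[ActivationFunction.gelu]

x = torch.randn(2, hidden_size)
print(norm(x).shape, act_fn(x).shape)   # both preserve the input shape

One side effect visible in the hunks below: an unsupported layer_norm value now surfaces as a KeyError from the lookup rather than the explicit ValueError raised by the removed branches.
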
2 changes: 1 addition & 1 deletion eole/decoders/cnn_decoder.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

from eole.modules import ConvMultiStepAttention
from eole.modules.conv_multi_step_attention import ConvMultiStepAttention
from eole.utils.cnn_factory import shape_transform, GatedConv
from eole.decoders.decoder import DecoderBase

3 changes: 2 additions & 1 deletion eole/decoders/rnn_decoder.py
@@ -3,7 +3,8 @@

from eole.decoders.decoder import DecoderBase
from eole.modules.stacked_rnn import StackedLSTM, StackedGRU
from eole.modules import context_gate_factory, GlobalAttention
from eole.modules.gate import context_gate_factory
from eole.modules.global_attention import GlobalAttention


class RNNDecoderBase(DecoderBase):
20 changes: 7 additions & 13 deletions eole/decoders/transformer_base.py
@@ -6,10 +6,11 @@
import torch
import torch.nn as nn
from eole.decoders.decoder import DecoderBase
from eole.modules import MultiHeadedAttention, AverageAttention
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.average_attn import AverageAttention
from eole.modules.transformer_mlp import MLP
from eole.modules.moe import MoE
from eole.modules.rmsnorm import RMSNorm
from eole.constants import LayerNorm


class TransformerDecoderLayerBase(nn.Module):
@@ -23,14 +24,7 @@ def __init__(
model_config (eole.config.TransformerDecoderConfig): full decoder config
"""
super(TransformerDecoderLayerBase, self).__init__()
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.parallel_residual = model_config.parallel_residual
self.shared_layer_norm = model_config.shared_layer_norm
self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
@@ -39,7 +33,7 @@ def __init__(
self.sliding_window = model_config.sliding_window
self.self_attn_type = model_config.self_attn_type

self.input_layernorm = layernorm(
self.input_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
if self.self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
@@ -55,11 +49,11 @@ def __init__(
aan_useffn=model_config.aan_useffn,
)
self.dropout = nn.Dropout(self.dropout_p)
self.post_attention_layernorm = layernorm(
self.post_attention_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.parallel_residual and not model_config.shared_layer_norm:
self.residual_layernorm = layernorm(
self.residual_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.num_experts > 0:
28 changes: 8 additions & 20 deletions eole/decoders/transformer_decoder.py
@@ -9,8 +9,9 @@
TransformerDecoderLayerBase,
TransformerDecoderBase,
)
from eole.modules import MultiHeadedAttention, AverageAttention
from eole.modules.rmsnorm import RMSNorm
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.average_attn import AverageAttention
from eole.constants import LayerNorm


class TransformerDecoderLayer(TransformerDecoderLayerBase):
@@ -35,15 +36,8 @@ def __init__(
model_config,
running_config=running_config,
)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.precontext_layernorm = layernorm(

self.precontext_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.context_attn = MultiHeadedAttention(
@@ -151,14 +145,6 @@ def __init__(
super(TransformerDecoder, self).__init__(
model_config, running_config=running_config
)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.transformer_layers = nn.ModuleList(
[
@@ -170,7 +156,9 @@
]
)
# This is the Decoder out layer norm
self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)
self.layer_norm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

def forward(self, emb, **kwargs):
"""Decode, possibly stepwise."""
15 changes: 5 additions & 10 deletions eole/decoders/transformer_lm_decoder.py
@@ -10,7 +10,7 @@
TransformerDecoderLayerBase,
TransformerDecoderBase,
)
from eole.modules.rmsnorm import RMSNorm
from eole.constants import LayerNorm


class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
@@ -86,14 +86,7 @@ def __init__(
running_config=None,
):
super(TransformerLMDecoder, self).__init__(model_config)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.transformer_layers = nn.ModuleList(
[
TransformerLMDecoderLayer(
@@ -104,7 +97,9 @@
]
)
# This is the Decoder out layer norm
self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)
self.layer_norm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

def forward(self, emb, **kwargs):
"""Decode, possibly stepwise."""
34 changes: 8 additions & 26 deletions eole/encoders/transformer.py
@@ -5,10 +5,9 @@
import torch.nn as nn

from eole.encoders.encoder import EncoderBase
from eole.modules import MultiHeadedAttention
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.transformer_mlp import MLP

from eole.modules.rmsnorm import RMSNorm
from eole.constants import LayerNorm


class TransformerEncoderLayer(nn.Module):
@@ -26,18 +25,10 @@ def __init__(
running_config=None,
):
super(TransformerEncoderLayer, self).__init__()
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.parallel_residual = model_config.parallel_residual
self.dropout_p = getattr(running_config, "dropout", [0.0])[0]

self.input_layernorm = layernorm(
self.input_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.self_attn = MultiHeadedAttention(
@@ -47,7 +38,7 @@ def __init__(
attn_type="self",
)
self.dropout = nn.Dropout(self.dropout_p)
self.post_attention_layernorm = layernorm(
self.post_attention_layernorm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)
self.mlp = MLP(
@@ -120,18 +111,9 @@ def __init__(
]
)
# This is the Encoder out layer norm
if model_config.layer_norm == "standard":
self.layer_norm = nn.LayerNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
elif model_config.layer_norm == "rms":
self.layer_norm = RMSNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.layer_norm = LayerNorm[model_config.layer_norm](
model_config.hidden_size, eps=model_config.norm_eps
)

@classmethod
def from_config(cls, model_config, running_config=None):
2 changes: 1 addition & 1 deletion eole/models/model.py
@@ -20,7 +20,7 @@
from eole.encoders import str2enc
from eole.decoders import str2dec
from eole.constants import DefaultTokens
from eole.modules import Embeddings
from eole.modules.embeddings import Embeddings

from eole.models.model_saver import load_checkpoint
from eole.modules.estimator import FeedForward
25 changes: 0 additions & 25 deletions eole/modules/__init__.py

This file was deleted.
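
With eole/modules/__init__.py gone, the convenience re-exports it provided disappear as well, which is why the surrounding hunks rewrite imports to target the defining submodule. A before/after sketch using names that appear in this diff:

# Before: re-exported through eole/modules/__init__.py
# from eole.modules import MultiHeadedAttention, AverageAttention

# After: import each class from the module that defines it
from eole.modules.multi_headed_attn import MultiHeadedAttention
from eole.modules.average_attn import AverageAttention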

2 changes: 1 addition & 1 deletion eole/modules/average_attn.py
@@ -6,7 +6,7 @@
from torch import Tensor
from typing import Optional
from eole.modules.transformer_mlp import MLP
from eole.modules.transformer_mlp import ActivationFunction
from eole.constants import ActivationFunction


def cumulative_average_mask(
22 changes: 2 additions & 20 deletions eole/modules/transformer_mlp.py
@@ -1,29 +1,11 @@
"""MLP network from "Attention is All You Need"."""

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint
from torch.nn.utils import skip_init
from torch.distributed import all_reduce
from enum import Enum


class ActivationFunction(str, Enum):
relu = "relu"
gelu = "gelu"
silu = "silu"
gated_gelu = "gated-gelu"
gated_silu = "gated-silu"


# for silu, see: https://arxiv.org/pdf/2002.05202.pdf
ACTIVATION_FUNCTIONS = {
ActivationFunction.relu: F.relu,
ActivationFunction.gelu: F.gelu,
ActivationFunction.silu: F.silu,
ActivationFunction.gated_gelu: F.gelu,
ActivationFunction.gated_silu: F.silu,
}
from eole.constants import ACTIVATION_FUNCTIONS


class MLP(nn.Module):
2 changes: 1 addition & 1 deletion eole/tests/test_attention.py
@@ -23,7 +23,7 @@ def test_masked_global_attention(self):
enc_out = Variable(torch.randn(batch_size, src_len.max(), dim))
enc_final_hs = Variable(torch.randn(batch_size, dim))

attn = eole.modules.GlobalAttention(dim)
attn = eole.modules.global_attention.GlobalAttention(dim)

_, alignments = attn(enc_final_hs, enc_out, src_len=src_len)
# TODO: fix for pytorch 0.3
4 changes: 2 additions & 2 deletions eole/utils/cnn_factory.py
@@ -5,7 +5,7 @@
import torch.nn as nn
import torch.nn.init as init

import eole.modules
from eole.modules.weight_norm import WeightNormConv2d

SCALE_WEIGHT = 0.5**0.5

@@ -20,7 +20,7 @@ class GatedConv(nn.Module):

def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
super(GatedConv, self).__init__()
self.conv = eole.modules.WeightNormConv2d(
self.conv = WeightNormConv2d(
input_size,
2 * input_size,
kernel_size=(width, 1),