int: Rename original `combiner_registry` to `combiner_config_registry`, update decorator name by ksbrar · Pull Request #3516 · ludwig-ai/ludwig · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

int: Rename original combiner_registry to combiner_config_registry, update decorator name #3516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("comparator")
@register_combiner_config("comparator")
@ludwig_dataclass
class ComparatorCombinerConfig(BaseCombinerConfig):
"""Parameters for comparator combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("concat")
@register_combiner_config("concat")
@ludwig_dataclass
class ConcatCombinerConfig(BaseCombinerConfig):
"""Parameters for concat combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/project_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("project_aggregate")
@register_combiner_config("project_aggregate")
@ludwig_dataclass
class ProjectAggregateCombinerConfig(BaseCombinerConfig):
type: str = schema_utils.ProtectedString(
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.sequence_concat import MAIN_SEQUENCE_FEATURE_DESCRIPTION
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.utils import EncoderDataclassField
from ludwig.schema.metadata import COMBINER_METADATA
Expand All @@ -19,7 +19,7 @@


@DeveloperAPI
@register_combiner("sequence")
@register_combiner_config("sequence")
@ludwig_dataclass
class SequenceCombinerConfig(BaseCombinerConfig):
"""Parameters for sequence combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/sequence_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass

Expand All @@ -19,7 +19,7 @@


@DeveloperAPI
@register_combiner("sequence_concat")
@register_combiner_config("sequence_concat")
@ludwig_dataclass
class SequenceConcatCombinerConfig(BaseCombinerConfig):
"""Parameters for sequence concat combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/tab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.common_transformer_options import CommonTransformerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("tabtransformer")
@register_combiner_config("tabtransformer")
@ludwig_dataclass
class TabTransformerCombinerConfig(BaseCombinerConfig, CommonTransformerConfig):
"""Parameters for tab transformer combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("tabnet")
@register_combiner_config("tabnet")
@ludwig_dataclass
class TabNetCombinerConfig(BaseCombinerConfig):
"""Parameters for tabnet combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.common_transformer_options import CommonTransformerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("transformer")
@register_combiner_config("transformer")
@ludwig_dataclass
class TransformerCombinerConfig(BaseCombinerConfig, CommonTransformerConfig):
"""Parameters for transformer combiner."""
Expand Down
18 changes: 9 additions & 9 deletions ludwig/schema/combiners/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@
DEFAULT_VALUE = "concat"
DESCRIPTION = "Select the combiner type."

combiner_registry = Registry[Type[BaseCombinerConfig]]()
combiner_config_registry = Registry[Type[BaseCombinerConfig]]()


@DeveloperAPI
def register_combiner(name: str):
def register_combiner_config(name: str):
def wrap(cls: Type[BaseCombinerConfig]):
combiner_registry[name] = cls
combiner_config_registry[name] = cls
return cls

return wrap


@DeveloperAPI
def get_combiner_registry():
return combiner_registry
return combiner_config_registry


@DeveloperAPI
def get_combiner_jsonschema():
"""Returns a JSON schema structured to only require a `type` key and then conditionally apply a corresponding
combiner's field constraints."""
combiner_types = sorted(list(combiner_registry.keys()))
combiner_types = sorted(list(combiner_config_registry.keys()))
parameter_metadata = convert_metadata_to_json(
ParameterMetadata.from_dict(
{
Expand Down Expand Up @@ -72,17 +72,17 @@ def get_combiner_descriptions():
Returns:
dict: A dictionary of combiner descriptions.
"""
return {k: convert_metadata_to_json(v[TYPE]) for k, v in COMBINER_METADATA.items() if k in combiner_registry}
return {k: convert_metadata_to_json(v[TYPE]) for k, v in COMBINER_METADATA.items() if k in combiner_config_registry}


@DeveloperAPI
def get_combiner_conds() -> List[Dict[str, Any]]:
"""Returns a list of if-then JSON clauses for each combiner type in `combiner_config_registry` and its
properties' constraints."""
combiner_types = sorted(list(combiner_registry.keys()))
combiner_types = sorted(list(combiner_config_registry.keys()))
conds = []
for combiner_type in combiner_types:
combiner_cls = combiner_registry[combiner_type]
combiner_cls = combiner_config_registry[combiner_type]
schema_cls = combiner_cls
combiner_schema = schema_utils.unload_jsonschema_from_marshmallow_class(schema_cls)
combiner_props = combiner_schema["properties"]
Expand All @@ -97,7 +97,7 @@ def __init__(self):
# For registration of all combiners
import ludwig.combiners.combiners # noqa

super().__init__(registry=combiner_registry, default_value=DEFAULT_VALUE, description=DESCRIPTION)
super().__init__(registry=combiner_config_registry, default_value=DEFAULT_VALUE, description=DESCRIPTION)

def get_schema_from_registry(self, key: str) -> Type[schema_utils.BaseMarshmallowConfig]:
return self.registry[key]
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_custom_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ludwig.modules.metric_modules import LossMetric, register_metric
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner as register_combiner_schema
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.decoders.base import BaseDecoderConfig
from ludwig.schema.decoders.utils import register_decoder_config
from ludwig.schema.encoders.base import BaseEncoderConfig
Expand Down Expand Up @@ -55,7 +55,7 @@ class CustomLossConfig(BaseLossConfig):
type: str = "custom_loss"


@register_combiner_schema("custom_combiner")
@register_combiner_config("custom_combiner")
@dataclass
class CustomTestCombinerConfig(BaseCombinerConfig):
type: str = "custom_combiner"
Expand Down