Add default LoRA target modules for Phi-2 by arnavgarg1 · Pull Request #3911 · ludwig-ai/ludwig
Merged (3 commits) on Jan 23, 2024
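For context, a minimal sketch of the kind of Ludwig config this change targets: a Phi-2 base model fine-tuned with a LoRA adapter where adapter.target_modules is left unset. The dict-style representation and the feature names are illustrative assumptions, not taken from this PR; the defaults shown in the trailing comment come from the diff below.

# Illustrative Ludwig LLM config, written as a Python dict. The feature names are
# made up; the relevant part is that "target_modules" is omitted under the adapter.
config = {
    "model_type": "llm",
    "base_model": "microsoft/phi-2",
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "adapter": {"type": "lora"},  # no target_modules specified
}

# With this PR, Ludwig fills in the Phi-2 defaults when the config is built, i.e.
# config["adapter"]["target_modules"] becomes
# ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"].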
31 changes: 28 additions & 3 deletions ludwig/schema/model_types/utils.py
@@ -34,7 +34,7 @@
 from ludwig.schema.trainer import ECDTrainerConfig
 from ludwig.types import HyperoptConfigDict, ModelConfigDict
 from ludwig.utils.data_utils import get_sanitized_feature_name
-from ludwig.utils.llm_utils import get_context_len
+from ludwig.utils.llm_utils import _PHI_MODELS, get_context_len
 
 if TYPE_CHECKING:
     from ludwig.schema.model_types.base import ModelConfig
@@ -315,8 +315,14 @@ def set_llm_parameters(config: "ModelConfig") -> None:
 
     # HACK(Arnav): Set Mixtral target modules when using LoRA
     # GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3853
+    # PEFT PR: https://github.com/huggingface/peft/pull/1376
     _set_mixtral_target_modules(config)
+
+    # HACK(Arnav): Set Phi-2 target modules when using LoRA
+    # GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3910
+    # PEFT PR: https://github.com/huggingface/peft/pull/1375
+    _set_phi2_target_modules(config)
 
 
 def _set_llm_tokenizers(config: "ModelConfig") -> None:
     """Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the
@@ -421,9 +427,28 @@ def _set_mixtral_target_modules(config: "ModelConfig") -> None:
     if config.adapter.type != "lora" or config.adapter.target_modules:
         return
 
-    logger.info("Setting adapter target modules to ['q_proj', 'v_proj'] for Mixtral 7x8 base model with LoRA adapter.")
+    target_modules = ["q_proj", "v_proj"]
 
-    config.adapter.target_modules = ["q_proj", "v_proj"]
+    logger.info(f"Setting adapter target modules to {target_modules} for Mixtral 7x8 base model with LoRA adapter.")
+    config.adapter.target_modules = target_modules
+
+
+def _set_phi2_target_modules(config: "ModelConfig") -> None:
+    """If the base model is Phi-2, LoRA is enabled and the target modules are not set, set the target modules to
+    maximize performance."""
+    if config.base_model not in _PHI_MODELS:
+        return
+
+    if not config.adapter:
+        return
+
+    if config.adapter.type != "lora" or config.adapter.target_modules:
+        return
+
+    target_modules = ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"]
+
+    logger.info(f"Setting adapter target modules to {target_modules} for Phi-2 base model with LoRA adapter.")
+    config.adapter.target_modules = target_modules
 
 
 @DeveloperAPI
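The module names chosen above map onto PEFT in the obvious way. As a hedged sketch, roughly what the new default corresponds to at the PEFT level (illustrative only: Ludwig builds its adapter config internally, and this assumes the peft package is installed):

from peft import LoraConfig

# Roughly the PEFT-level equivalent of the Ludwig default set by _set_phi2_target_modules.
# Other LoRA hyperparameters (r, lora_alpha, dropout, ...) are left at PEFT's defaults here.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
)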
10 changes: 6 additions & 4 deletions ludwig/utils/llm_utils.py
@@ -23,13 +23,11 @@
 
 logger = logging.getLogger(__name__)
 
-transformers_436 = version.parse(transformers.__version__) >= version.parse("4.36.0")
-
 FALLBACK_CONTEXT_LEN = 2048
 
+transformers_436 = version.parse(transformers.__version__) >= version.parse("4.36.0")
 
-# Phi models don't support "device_map='auto'" at model load time as of transformers 4.37.0.
-_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = {
+_PHI_MODELS = {
     "susnato/phi-1_dev",
     "susnato/phi-1_5_dev",
     "susnato/phi-2",
@@ -38,6 +36,10 @@
     "microsoft/phi-2",
 }
 
+_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = set()
+# Phi models don't support "device_map='auto'" at model load time as of transformers 4.37.0.
+_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION.update(_PHI_MODELS)
+
 
 @default_retry(tries=8)
 def load_pretrained_from_config(
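Splitting the Phi model list out as _PHI_MODELS lets one set serve two purposes: the LoRA target-module defaulting in model_types/utils.py and the existing device_map="auto" exclusion. As a hypothetical sketch of how such an exclusion set is typically consulted (the helper name, signature, and the inline stand-in set below are illustrative assumptions, not Ludwig's actual loader code):

from typing import Optional

# Stand-in for the module-level set built in the diff above; listed entries are a subset.
_MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION = {"microsoft/phi-2", "susnato/phi-2"}

def resolve_device_map(base_model: str) -> Optional[str]:
    """Hypothetical helper: choose a device_map value for from_pretrained()."""
    # Phi models don't support device_map="auto" at load time as of transformers 4.37.0,
    # so return None and let the caller place the model explicitly.
    if base_model in _MODELS_WITH_DEVICE_MAP_AUTO_EXCLUSION:
        return None
    return "auto"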