Remove dependency from config to utils by lukemerrick · Pull Request #2034 · Lightning-AI/litgpt · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Remove dependency from config to utils #2034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@
import yaml
from typing_extensions import Self

from litgpt.utils import find_multiple

def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to a multiple of ``k``.

    Returns ``n`` unchanged when it is already divisible by ``k``; otherwise returns the
    next multiple of ``k`` above ``n``.

    NOTE: This helper lives in this module instead of `litgpt.utils` so that the file can be
    imported for configuration manipulation in Python environments that lack some of the
    dependencies demanded by `litgpt.utils`.
    """
    assert k > 0
    remainder = n % k
    return n if remainder == 0 else n + k - remainder


@dataclass
Expand Down
7 changes: 0 additions & 7 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,6 @@ def find_resume_path(resume: Union[bool, Literal["auto"], Path], out_dir: Path)
return resume_path


def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is not less than ``n``."""
    assert k > 0
    leftover = n % k
    if leftover:
        # Not aligned: step up to the following multiple of k.
        return n + k - leftover
    return n


def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
total = 0
for p in module.parameters():
Expand Down
11 changes: 11 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import litgpt.config as config_module
from litgpt import Config
from litgpt.config import find_multiple


def test_config():
Expand Down Expand Up @@ -103,3 +104,13 @@ def test_head_size(head_size):
config = Config(head_size)

assert config.head_size == head_size or config.n_embd // config.n_head


def test_find_multiple():
    """Check ``find_multiple`` against a table of sample inputs."""
    # Each entry is (n, k, expected result).
    cases = [
        (17, 5, 20),
        (30, 7, 35),
        (10, 2, 10),
        (5, 10, 10),
        (50254, 128, 50304),
        (50254, 256, 50432),
        (50254, 512, 50688),
    ]
    for n, k, expected in cases:
        assert find_multiple(n, k) == expected
11 changes: 0 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
find_multiple,
find_resume_path,
fix_and_load_json,
incremental_save,
Expand All @@ -45,16 +44,6 @@
)


def test_find_multiple():
    """Check ``find_multiple`` on small and vocabulary-sized inputs."""
    # Maps (n, k) argument pairs to the value the call must return.
    expected_by_args = {
        (17, 5): 20,
        (30, 7): 35,
        (10, 2): 10,
        (5, 10): 10,
        (50254, 128): 50304,
        (50254, 256): 50432,
        (50254, 512): 50688,
    }
    for args, expected in expected_by_args.items():
        assert find_multiple(*args) == expected


# match fails on windows. why did they have to use backslashes?
@_RunIf(skip_windows=True)
def test_check_valid_checkpoint_dir(tmp_path):
Expand Down
Loading
0