Minor cleanup of oumi fetch by taenin · Pull Request #1463 · oumi-ai/oumi · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Minor cleanup of oumi fetch #1463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions src/oumi/cli/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
from pathlib import Path
from typing import Annotated, Optional

import requests
import typer
import yaml
from requests.exceptions import RequestException

from oumi.utils.logging import logger

CONTEXT_ALLOW_EXTRA_ARGS = {"allow_extra_args": True, "ignore_unknown_options": True}
CONFIG_FLAGS = ["--config", "-c"]
OUMI_FETCH_DIR = "~/.oumi/fetch"
OUMI_GITHUB_RAW = "https://raw.githubusercontent.com/oumi-ai/oumi/main"
_OUMI_PREFIX = "oumi://"


def parse_extra_cli_args(ctx: typer.Context) -> list[str]:
Expand Down Expand Up @@ -141,7 +146,7 @@ def set_log_level(level: Optional[LogLevel]):
]


def resolve_oumi_prefix(
def _resolve_oumi_prefix(
config_path: str, output_dir: Optional[Path] = None
) -> tuple[str, Path]:
"""Resolves oumi:// prefix and determines output directory.
Expand All @@ -153,9 +158,8 @@ def resolve_oumi_prefix(
Returns:
tuple[str, Path]: (cleaned path, output directory)
"""
oumi_prefix = "oumi://"
if config_path.lower().startswith(oumi_prefix):
config_path = config_path[len(oumi_prefix) :]
if config_path.lower().startswith(_OUMI_PREFIX):
config_path = config_path[len(_OUMI_PREFIX) :]

config_dir = output_dir or os.environ.get("OUMI_DIR") or OUMI_FETCH_DIR
config_dir = Path(config_dir).expanduser()
Expand All @@ -165,22 +169,54 @@ def resolve_oumi_prefix(


def resolve_and_fetch_config(
    config_path: str,
    output_dir: Optional[Path] = None,
    force: bool = True,
    *,
    timeout: float = 30.0,
) -> Path:
    """Resolve an oumi:// config path, downloading the config if needed.

    Paths without the oumi:// prefix are treated as local paths and returned
    unchanged. Prefixed paths are downloaded from ``OUMI_GITHUB_RAW``. They are
    saved under the directory returned by ``_resolve_oumi_prefix``: the
    ``output_dir`` override, else the ``OUMI_DIR`` environment variable, else
    ``OUMI_FETCH_DIR``.

    Args:
        config_path: Original config path that may contain the oumi:// prefix.
        output_dir: Optional override for the output directory.
        force: Whether to overwrite an existing config.
        timeout: Seconds to wait for the GitHub request before giving up.

    Returns:
        Path: Local path to the config file.

    Raises:
        RuntimeError: If the config already exists locally and ``force`` is
            False.
        requests.exceptions.RequestException: If the download fails.
        yaml.YAMLError: If the downloaded content is not valid YAML.
    """
    if not config_path.lower().startswith(_OUMI_PREFIX):
        return Path(config_path)

    # Remove the oumi:// prefix and pick the destination directory.
    new_config_path, config_dir = _resolve_oumi_prefix(config_path, output_dir)

    # Strip leading slashes once, so the local path and the GitHub URL agree.
    # A leading "/" would otherwise make `config_dir / path` an absolute path
    # outside the fetch directory.
    relative_path = new_config_path.lstrip("/")
    local_path = (config_dir or Path(OUMI_FETCH_DIR).expanduser()) / relative_path

    # Check the destination first, so we don't hit the network needlessly.
    if local_path.exists() and not force:
        msg = f"Config already exists at {local_path}. Use --force to overwrite"
        logger.error(msg)
        raise RuntimeError(msg)

    github_url = f"{OUMI_GITHUB_RAW}/{relative_path}"
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(github_url, timeout=timeout)
        response.raise_for_status()
        config_content = response.text
        # Validate the YAML before touching the filesystem.
        yaml.safe_load(config_content)
    except RequestException as e:
        logger.error(f"Failed to download config from GitHub: {e}")
        raise
    except yaml.YAMLError:
        logger.error("Invalid YAML configuration")
        raise

    # Save to the destination.
    if local_path.exists():
        logger.warning(f"Overwriting existing config at {local_path}")
    local_path.parent.mkdir(parents=True, exist_ok=True)
    # Write explicitly as UTF-8, so non-ASCII configs don't depend on the locale.
    local_path.write_text(config_content, encoding="utf-8")
    logger.info(f"Successfully downloaded config to {local_path}")

    return local_path
48 changes: 7 additions & 41 deletions src/oumi/cli/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
from pathlib import Path
from typing import Annotated, Optional

import requests
import typer
import yaml
from requests.exceptions import RequestException

from oumi.cli.cli_utils import OUMI_FETCH_DIR, resolve_oumi_prefix
from oumi.cli.cli_utils import _OUMI_PREFIX, resolve_and_fetch_config
from oumi.utils.logging import logger

OUMI_GITHUB_RAW = "https://raw.githubusercontent.com/oumi-ai/oumi/main"
Expand All @@ -30,7 +27,8 @@ def fetch(
config_path: Annotated[
str,
typer.Argument(
help="Path to config (e.g. oumi://smollm/inference/135m_infer.yaml)"
help="Path to config "
"(e.g. oumi://configs/recipes/smollm/inference/135m_infer.yaml)"
),
],
output_dir: Annotated[
Expand All @@ -49,39 +47,7 @@ def fetch(
] = False,
) -> None:
"""Fetch configuration files from GitHub repository."""
# Remove oumi:// prefix if present
config_path, config_dir = resolve_oumi_prefix(config_path, output_dir)

try:
# Check destination first
local_path = (config_dir or Path(OUMI_FETCH_DIR).expanduser()) / config_path
if local_path.exists() and not force:
msg = f"Config already exists at {local_path}. Use --force to overwrite"
logger.error(msg)
typer.echo(msg, err=True)
raise typer.Exit(code=1)

# Fetch from GitHub
github_url = f"{OUMI_GITHUB_RAW}/{config_path.lstrip('/')}"
response = requests.get(github_url)
response.raise_for_status()
config_content = response.text

# Validate YAML
yaml.safe_load(config_content)

# Save to destination
if local_path.exists():
logger.warning(f"Overwriting existing config at {local_path}")
local_path.parent.mkdir(parents=True, exist_ok=True)

with open(local_path, "w") as f:
f.write(config_content)
logger.info(f"Successfully downloaded config to {local_path}")

except RequestException as e:
logger.error(f"Failed to download config from GitHub: {e}")
raise typer.Exit(1)
except yaml.YAMLError:
logger.error("Invalid YAML configuration")
raise typer.Exit(1)
if not config_path.lower().startswith(_OUMI_PREFIX):
logger.info(f"Prepending {_OUMI_PREFIX} to config path")
config_path = _OUMI_PREFIX + config_path
_ = resolve_and_fetch_config(config_path, output_dir, force)
17 changes: 5 additions & 12 deletions src/oumi/cli/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
from pathlib import Path
from typing import Annotated, Final, Optional

import typer
Expand All @@ -33,16 +32,6 @@ def infer(
help="Path to the configuration file for inference.",
),
],
output_dir: Annotated[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there another name for the flag that you think would be clearer? It's also fine to delete it if you feel it's not likely to be set.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for all entrypoints except `oumi fetch` we should use the env var to determine this value. No need for the param here.

Optional[Path],
typer.Option(
"--output-dir",
help=(
"Directory to save configs "
"(defaults to OUMI_DIR env var or ~/.oumi/fetch)"
),
),
] = None,
interactive: Annotated[
bool,
typer.Option("-i", "--interactive", help="Run in an interactive session."),
Expand Down Expand Up @@ -87,7 +76,11 @@ def infer(
"""
extra_args = cli_utils.parse_extra_cli_args(ctx)

config = str(cli_utils.resolve_and_fetch_config(config, output_dir))
config = str(
cli_utils.resolve_and_fetch_config(
config,
)
)

# Delayed imports
from oumi import infer as oumi_infer
Expand Down
Loading
0