From acdadf5ae1413593cf93b30ac78fc8438f985043 Mon Sep 17 00:00:00 2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Sun, 6 Apr 2025 17:07:18 -0700 Subject: [PATCH 01/15] Add `vllm` to `gpu` optional dependencies (#1614) --- configs/examples/bulk_inference/gcp_job.yaml | 2 +- configs/examples/grpo_tldr/gcp_job.yaml | 2 +- .../examples/letter_counting/evaluation/eval.yaml | 2 +- .../examples/letter_counting/evaluation/gcp_job.yaml | 2 +- configs/examples/letter_counting/grpo/gcp_job.yaml | 2 +- .../recipes/llama3_1/inference/8b_rvllm_infer.yaml | 2 +- .../recipes/llama3_2/inference/1b_vllm_infer.yaml | 2 +- .../recipes/llama3_2/inference/3b_vllm_infer.yaml | 2 +- .../recipes/llama3_3/inference/70b_vllm_infer.yaml | 2 +- .../llama3_2_vision/inference/11b_rvllm_infer.yaml | 2 +- .../llama3_2_vision/inference/11b_vllm_infer.yaml | 2 +- .../vision/llava_7b/inference/vllm_infer.yaml | 2 +- .../recipes/vision/phi3/inference/vllm_infer.yaml | 2 +- .../recipes/vision/phi4/inference/vllm_infer.yaml | 2 +- .../vision/qwen2_5_vl_3b/inference/vllm_infer.yaml | 2 +- .../vision/qwen2_vl_2b/inference/vllm_infer.yaml | 2 +- docs/user_guides/infer/inference_engines.md | 3 +++ ...ustom Evaluation (Hallucination Classifier).ipynb | 4 ++-- notebooks/Oumi - Distill a Large Model.ipynb | 2 +- notebooks/Oumi - Finetuning Tutorial.ipynb | 2 +- notebooks/Oumi - MiniMath-R1-1.5B.ipynb | 2 +- .../Oumi - Using vLLM Engine for Inference.ipynb | 2 +- notebooks/Oumi - Vision Language Models.ipynb | 4 +--- pyproject.toml | 4 ++-- src/oumi/launcher/clusters/polaris_cluster.py | 2 +- tests/unit/launcher/clusters/test_polaris_cluster.py | 12 ++++++------ 26 files changed, 35 insertions(+), 34 deletions(-) diff --git a/configs/examples/bulk_inference/gcp_job.yaml b/configs/examples/bulk_inference/gcp_job.yaml index 4d28311d7..e056b7416 100644 --- a/configs/examples/bulk_inference/gcp_job.yaml +++ b/configs/examples/bulk_inference/gcp_job.yaml @@ -46,7 +46,7 @@ envs: setup: | 
set -e - pip install uv && uv pip install oumi[gpu] vllm>=0.7.3 + pip install uv && uv pip install oumi[gpu] run: | set -e # Exit if any command failed. diff --git a/configs/examples/grpo_tldr/gcp_job.yaml b/configs/examples/grpo_tldr/gcp_job.yaml index 1df040fbc..204ba1f50 100644 --- a/configs/examples/grpo_tldr/gcp_job.yaml +++ b/configs/examples/grpo_tldr/gcp_job.yaml @@ -31,7 +31,7 @@ envs: setup: | set -e - pip install uv && uv pip install oumi[gpu] "vllm>=0.7.3,<0.8.0" + pip install uv && uv pip install oumi[gpu] pip install -U flash-attn --no-build-isolation run: | diff --git a/configs/examples/letter_counting/evaluation/eval.yaml b/configs/examples/letter_counting/evaluation/eval.yaml index 3b060d505..23d874ffc 100644 --- a/configs/examples/letter_counting/evaluation/eval.yaml +++ b/configs/examples/letter_counting/evaluation/eval.yaml @@ -1,7 +1,7 @@ # Config to eval an LLM's ability to count letters in words. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # # Usage: diff --git a/configs/examples/letter_counting/evaluation/gcp_job.yaml b/configs/examples/letter_counting/evaluation/gcp_job.yaml index 24baa970f..3b4772e28 100644 --- a/configs/examples/letter_counting/evaluation/gcp_job.yaml +++ b/configs/examples/letter_counting/evaluation/gcp_job.yaml @@ -36,7 +36,7 @@ envs: setup: | set -e - pip install uv && uv pip install oumi[gpu,evaluation] "vllm>=0.7.3,<0.8.0" + pip install uv && uv pip install oumi[gpu,evaluation] run: | set -e # Exit if any command failed. diff --git a/configs/examples/letter_counting/grpo/gcp_job.yaml b/configs/examples/letter_counting/grpo/gcp_job.yaml index c3f092574..967071826 100644 --- a/configs/examples/letter_counting/grpo/gcp_job.yaml +++ b/configs/examples/letter_counting/grpo/gcp_job.yaml @@ -34,7 +34,7 @@ envs: setup: | set -e # vLLM needed for vLLM-powered generation during GRPO training. 
- pip install uv && uv pip install oumi[gpu] "vllm>=0.7.3,<0.8.0" + pip install uv && uv pip install oumi[gpu] pip install -U flash-attn --no-build-isolation run: | diff --git a/configs/recipes/llama3_1/inference/8b_rvllm_infer.yaml b/configs/recipes/llama3_1/inference/8b_rvllm_infer.yaml index 5422726a0..9daf24bdf 100644 --- a/configs/recipes/llama3_1/inference/8b_rvllm_infer.yaml +++ b/configs/recipes/llama3_1/inference/8b_rvllm_infer.yaml @@ -1,7 +1,7 @@ # Inference config for Llama 8B Instruct. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # - Request access to Llama 3.1: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct # diff --git a/configs/recipes/llama3_2/inference/1b_vllm_infer.yaml b/configs/recipes/llama3_2/inference/1b_vllm_infer.yaml index 092e16d6d..24f6bbcb1 100644 --- a/configs/recipes/llama3_2/inference/1b_vllm_infer.yaml +++ b/configs/recipes/llama3_2/inference/1b_vllm_infer.yaml @@ -1,7 +1,7 @@ # Inference config for Llama 1B Instruct. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct # diff --git a/configs/recipes/llama3_2/inference/3b_vllm_infer.yaml b/configs/recipes/llama3_2/inference/3b_vllm_infer.yaml index 3f7a54a51..de348a7ea 100644 --- a/configs/recipes/llama3_2/inference/3b_vllm_infer.yaml +++ b/configs/recipes/llama3_2/inference/3b_vllm_infer.yaml @@ -1,7 +1,7 @@ # Inference config for Llama 3B Instruct. 
# # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct # diff --git a/configs/recipes/llama3_3/inference/70b_vllm_infer.yaml b/configs/recipes/llama3_3/inference/70b_vllm_infer.yaml index c2ef24aff..2024ee3b8 100644 --- a/configs/recipes/llama3_3/inference/70b_vllm_infer.yaml +++ b/configs/recipes/llama3_3/inference/70b_vllm_infer.yaml @@ -1,7 +1,7 @@ # Inference config for Llama 3.3 70B Instruct with VLLM. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # - Request access to Llama 3.3: https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct # diff --git a/configs/recipes/vision/llama3_2_vision/inference/11b_rvllm_infer.yaml b/configs/recipes/vision/llama3_2_vision/inference/11b_rvllm_infer.yaml index a6afd4214..1b3f3c074 100644 --- a/configs/recipes/vision/llama3_2_vision/inference/11b_rvllm_infer.yaml +++ b/configs/recipes/vision/llama3_2_vision/inference/11b_rvllm_infer.yaml @@ -1,7 +1,7 @@ # Remote vLLM inference config for Llama 3.2 11B Vision Instruct. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct # diff --git a/configs/recipes/vision/llama3_2_vision/inference/11b_vllm_infer.yaml b/configs/recipes/vision/llama3_2_vision/inference/11b_vllm_infer.yaml index 160f91d8b..fc80e1678 100644 --- a/configs/recipes/vision/llama3_2_vision/inference/11b_vllm_infer.yaml +++ b/configs/recipes/vision/llama3_2_vision/inference/11b_vllm_infer.yaml @@ -1,7 +1,7 @@ # vLLM inference config for Llama 3.2 11B Vision Instruct. 
# # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` # - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct # diff --git a/configs/recipes/vision/llava_7b/inference/vllm_infer.yaml b/configs/recipes/vision/llava_7b/inference/vllm_infer.yaml index dcbb3231a..e2525a756 100644 --- a/configs/recipes/vision/llava_7b/inference/vllm_infer.yaml +++ b/configs/recipes/vision/llava_7b/inference/vllm_infer.yaml @@ -1,7 +1,7 @@ # Llava 7B vLLM inference config. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # # Usage: # oumi infer -i -c configs/recipes/vision/llava_7b/inference/vllm_infer.yaml \ diff --git a/configs/recipes/vision/phi3/inference/vllm_infer.yaml b/configs/recipes/vision/phi3/inference/vllm_infer.yaml index af59fd1b7..4d41435db 100644 --- a/configs/recipes/vision/phi3/inference/vllm_infer.yaml +++ b/configs/recipes/vision/phi3/inference/vllm_infer.yaml @@ -1,7 +1,7 @@ # Phi3 vision vLLM inference config. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # # Usage: # oumi infer -i -c configs/recipes/vision/phi3/inference/vllm_infer.yaml \ diff --git a/configs/recipes/vision/phi4/inference/vllm_infer.yaml b/configs/recipes/vision/phi4/inference/vllm_infer.yaml index 6a9d25092..860959d6f 100644 --- a/configs/recipes/vision/phi4/inference/vllm_infer.yaml +++ b/configs/recipes/vision/phi4/inference/vllm_infer.yaml @@ -1,7 +1,7 @@ # Phi-4-multimodal-instruct 5.6B vLLM inference config. 
# # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # - Run `pip install -U flash-attn --no-build-isolation` # # Usage: diff --git a/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml b/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml index b94642a14..1fc90dc1a 100644 --- a/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml +++ b/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml @@ -1,7 +1,7 @@ # Qwen 2.5 VL 3B vLLM inference config. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # # !Important! this model also requires the latest (dev) version of transformers. # Please read more at qwen2_5_vl_3b/README.md diff --git a/configs/recipes/vision/qwen2_vl_2b/inference/vllm_infer.yaml b/configs/recipes/vision/qwen2_vl_2b/inference/vllm_infer.yaml index 49e2eb27a..79cb22256 100644 --- a/configs/recipes/vision/qwen2_vl_2b/inference/vllm_infer.yaml +++ b/configs/recipes/vision/qwen2_vl_2b/inference/vllm_infer.yaml @@ -1,7 +1,7 @@ # vLLM inference config for Qwen2 VL 2B Instruct. # # Requirements: -# - Run `pip install vllm` +# - Run `pip install oumi[gpu]` # # Usage: # oumi infer -i -c configs/recipes/vision/qwen2_vl_2b/inference/vllm_infer.yaml \ diff --git a/docs/user_guides/infer/inference_engines.md b/docs/user_guides/infer/inference_engines.md index 87052572f..81a05b5f2 100644 --- a/docs/user_guides/infer/inference_engines.md +++ b/docs/user_guides/infer/inference_engines.md @@ -112,6 +112,9 @@ First, make sure to install the vLLM package: ```bash pip install vllm +# Alternatively, install all Oumi GPU dependencies, which takes care of installing a +# vLLM version compatible with your current Oumi version. 
+pip install oumi[gpu] ``` **Basic Usage** diff --git a/notebooks/Oumi - Build your own Custom Evaluation (Hallucination Classifier).ipynb b/notebooks/Oumi - Build your own Custom Evaluation (Hallucination Classifier).ipynb index b57018eff..40e58a31a 100644 --- a/notebooks/Oumi - Build your own Custom Evaluation (Hallucination Classifier).ipynb +++ b/notebooks/Oumi - Build your own Custom Evaluation (Hallucination Classifier).ipynb @@ -43,7 +43,7 @@ "\n", "### Oumi Installation\n", "\n", - "First, let's install Oumi and vLLM. You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html)." + "First, let's install Oumi. You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html)." ] }, { @@ -52,7 +52,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install oumi vllm" + "%pip install oumi" ] }, { diff --git a/notebooks/Oumi - Distill a Large Model.ipynb b/notebooks/Oumi - Distill a Large Model.ipynb index 8ba0bd56a..e71714d4e 100644 --- a/notebooks/Oumi - Distill a Large Model.ipynb +++ b/notebooks/Oumi - Distill a Large Model.ipynb @@ -66,7 +66,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install oumi[gpu] vllm" + "%pip install oumi[gpu]" ] }, { diff --git a/notebooks/Oumi - Finetuning Tutorial.ipynb b/notebooks/Oumi - Finetuning Tutorial.ipynb index d65c80770..eeb03aec9 100644 --- a/notebooks/Oumi - Finetuning Tutorial.ipynb +++ b/notebooks/Oumi - Finetuning Tutorial.ipynb @@ -60,7 +60,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install oumi[gpu] vllm" + "%pip install oumi[gpu]" ] }, { diff --git a/notebooks/Oumi - MiniMath-R1-1.5B.ipynb b/notebooks/Oumi - MiniMath-R1-1.5B.ipynb index 18974bcb1..dea5e35cf 100644 --- a/notebooks/Oumi - MiniMath-R1-1.5B.ipynb +++ b/notebooks/Oumi - MiniMath-R1-1.5B.ipynb @@ -63,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install oumi[gpu] vllm" + 
"%pip install oumi[gpu]" ] }, { diff --git a/notebooks/Oumi - Using vLLM Engine for Inference.ipynb b/notebooks/Oumi - Using vLLM Engine for Inference.ipynb index 0e4ba9c93..c9a4fdf4e 100644 --- a/notebooks/Oumi - Using vLLM Engine for Inference.ipynb +++ b/notebooks/Oumi - Using vLLM Engine for Inference.ipynb @@ -61,7 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install oumi[gpu] vllm" + "%pip install oumi[gpu]" ] }, { diff --git a/notebooks/Oumi - Vision Language Models.ipynb b/notebooks/Oumi - Vision Language Models.ipynb index 11ab3cb63..0abbe652f 100644 --- a/notebooks/Oumi - Vision Language Models.ipynb +++ b/notebooks/Oumi - Vision Language Models.ipynb @@ -431,9 +431,7 @@ "engine: NATIVE \n", "# Let's use the `native` engine (i.e., the underlying machine's default)\n", "# for inference. \n", - "# You can also consider VLLM, if are working with GPU for much faster inference. \n", - "# To install an Oumi tested/compatible version, use:\n", - "# pip install \"vllm>=0.7.3,<0.8.0\"" + "# You can also consider VLLM, if are working with GPU for much faster inference. 
" ] }, { diff --git a/pyproject.toml b/pyproject.toml index a40329001..87dc95013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,11 +111,12 @@ docs = [ "sphinxcontrib-typer", # Allows us to include typer CLI in the docs ] -# Dependencies that require a GPU to install +# Useful dependencies when running on GPU gpu = [ "liger-kernel>=0.5.0,<0.6", "nvidia-ml-py>=12.560.30,<12.561", "bitsandbytes>=0.45.0,<0.46", # Used for QLora, and PagedAdam implementation + "vllm>=0.7.3,<0.8.0", # For VLLMInferenceEngine ] # Targets for supported cloud providers @@ -163,7 +164,6 @@ ci_cpu = [ # gpu actions runner, so we skip it for now ci_gpu = [ "oumi[dev,docs,gcp,gpu]", - "vllm>=0.7.3,<0.8.0", "alpaca-eval>=0.6.6,<0.7", ] diff --git a/src/oumi/launcher/clusters/polaris_cluster.py b/src/oumi/launcher/clusters/polaris_cluster.py index c445d33c2..eb06ea890 100644 --- a/src/oumi/launcher/clusters/polaris_cluster.py +++ b/src/oumi/launcher/clusters/polaris_cluster.py @@ -257,7 +257,7 @@ def run_job(self, job: JobConfig) -> JobStatus: "if ! command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", # TODO Re-enable uv OPE-670 + "pip install -e '.[gpu]'", # TODO Re-enable uv OPE-670 ] self._client.run_commands(install_cmds) # Copy all file mounts. diff --git a/tests/unit/launcher/clusters/test_polaris_cluster.py b/tests/unit/launcher/clusters/test_polaris_cluster.py index 5b0421dbf..89e57a14e 100644 --- a/tests/unit/launcher/clusters/test_polaris_cluster.py +++ b/tests/unit/launcher/clusters/test_polaris_cluster.py @@ -367,7 +367,7 @@ def test_polaris_cluster_run_job(mock_datetime, mock_polaris_client): "if ! command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", + "pip install -e '.[gpu]'", ] ), call( @@ -469,7 +469,7 @@ def test_polaris_cluster_run_job_with_conda_setup(mock_datetime, mock_polaris_cl "if ! 
command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", + "pip install -e '.[gpu]'", ] ), call( @@ -570,7 +570,7 @@ def test_polaris_cluster_run_job_no_name(mock_datetime, mock_polaris_client): "if ! command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", + "pip install -e '.[gpu]'", ] ), call( @@ -659,7 +659,7 @@ def test_polaris_cluster_run_job_no_mounts(mock_datetime, mock_polaris_client): "if ! command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", + "pip install -e '.[gpu]'", ] ), call( @@ -750,7 +750,7 @@ def test_polaris_cluster_run_job_no_pbs(mock_datetime, mock_polaris_client): "if ! command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", + "pip install -e '.[gpu]'", ] ), call( @@ -833,7 +833,7 @@ def test_polaris_cluster_run_job_no_setup(mock_datetime, mock_polaris_client): "if ! command -v uv >/dev/null 2>&1; then", "pip install -U uv", "fi", - "pip install -e '.[gpu]' vllm", + "pip install -e '.[gpu]'", ] ), call( From a39686f52d0f6b6922a913aa3d9bb16d305fdc01 Mon Sep 17 00:00:00 2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Mon, 7 Apr 2025 00:05:56 -0700 Subject: [PATCH 02/15] [HallOumi] Update inference notebook (#1613) --- .../halloumi_inference_notebook.ipynb | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/configs/projects/halloumi/halloumi_inference_notebook.ipynb b/configs/projects/halloumi/halloumi_inference_notebook.ipynb index c41b13b23..44a3fd636 100644 --- a/configs/projects/halloumi/halloumi_inference_notebook.ipynb +++ b/configs/projects/halloumi/halloumi_inference_notebook.ipynb @@ -51,7 +51,23 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install oumi" + "%pip install git+https://github.com/oumi-ai/oumi.git" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you're 
running this notebook on a CUDA-compatible GPU and want to use vLLM for inference, make sure to install it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"vllm>=0.7.3,<0.8.0\"" ] }, { @@ -333,6 +349,7 @@ "local_config_str = \"\"\"\n", "model:\n", " model_name: \"oumi-ai/HallOumi-8B\"\n", + " model_max_length: 8192\n", " trust_remote_code: true\n", "\n", "generation:\n", From 14f6c7379e0cf2a9d5347d7b647329dae8b0c9df Mon Sep 17 00:00:00 2001 From: Matthew Persons Date: Mon, 7 Apr 2025 13:10:11 -0700 Subject: [PATCH 03/15] Update llama4 GCP jobs for non-dev environments. (#1621) --- configs/recipes/llama4/sft/scout_base_full/gcp_job.yaml | 2 +- configs/recipes/llama4/sft/scout_instruct_full/gcp_job.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/recipes/llama4/sft/scout_base_full/gcp_job.yaml b/configs/recipes/llama4/sft/scout_base_full/gcp_job.yaml index da63abc99..37040ac2d 100644 --- a/configs/recipes/llama4/sft/scout_base_full/gcp_job.yaml +++ b/configs/recipes/llama4/sft/scout_base_full/gcp_job.yaml @@ -57,7 +57,7 @@ run: | set -x oumi distributed torchrun \ -m oumi train \ - -c configs/recipes/llama4/sft/scout_base_full/train.yaml \ + -c oumi://configs/recipes/llama4/sft/scout_base_full/train.yaml \ --training.run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" echo "Node ${SKYPILOT_NODE_RANK} is all done!" 
diff --git a/configs/recipes/llama4/sft/scout_instruct_full/gcp_job.yaml b/configs/recipes/llama4/sft/scout_instruct_full/gcp_job.yaml index 37823d247..2b1395510 100644 --- a/configs/recipes/llama4/sft/scout_instruct_full/gcp_job.yaml +++ b/configs/recipes/llama4/sft/scout_instruct_full/gcp_job.yaml @@ -57,7 +57,7 @@ run: | set -x oumi distributed torchrun \ -m oumi train \ - -c configs/recipes/llama4/sft/scout_instruct_full/train.yaml \ + -c oumi://configs/recipes/llama4/sft/scout_instruct_full/train.yaml \ --training.run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" echo "Node ${SKYPILOT_NODE_RANK} is all done!" From 6539645184b010a116bc8faedf2254b2f7fbb1e6 Mon Sep 17 00:00:00 2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Mon, 7 Apr 2025 23:53:59 -0700 Subject: [PATCH 04/15] Update transformers to 4.51.0 (#1620) --- .../recipes/vision/qwen2_5_vl_3b/README.md | 20 ------------------- .../vision/qwen2_5_vl_3b/inference/infer.yaml | 5 ----- .../qwen2_5_vl_3b/inference/vllm_infer.yaml | 3 --- .../vision/qwen2_5_vl_3b/sft/full/train.yaml | 3 --- .../vision/qwen2_5_vl_3b/sft/lora/train.yaml | 3 --- pyproject.toml | 2 +- tests/e2e/test_eval_e2e.py | 1 + tests/unit/builders/test_models.py | 2 +- 8 files changed, 3 insertions(+), 36 deletions(-) diff --git a/configs/recipes/vision/qwen2_5_vl_3b/README.md b/configs/recipes/vision/qwen2_5_vl_3b/README.md index 81ebd2643..0c72ceffb 100644 --- a/configs/recipes/vision/qwen2_5_vl_3b/README.md +++ b/configs/recipes/vision/qwen2_5_vl_3b/README.md @@ -2,23 +2,3 @@ Configs for the **`Qwen2.5-VL`** 3B model. 🔗 **Reference:** [Qwen2.5-VL-3B-Instruct on Hugging Face](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) - ---- - -❗ **Important Note** -As of **February 2025**, `Qwen2.5-VL` is integrated into the **latest** `transformers` _dev_ version. 
- -⚠️ **Earlier versions may cause a runtime error:** -KeyError: ‘qwen2_5_vl’ - -Oumi has successfully tested this integration with: -- **SFT training** -- **Native inference** using **`transformers 4.49.0.dev0`** - -To update `transformers` to this version, run: - -```sh -pip install git+https://github.com/huggingface/transformers.git -``` - -⚠️ Caution: This upgrade may break other Oumi utilities. Proceed carefully. diff --git a/configs/recipes/vision/qwen2_5_vl_3b/inference/infer.yaml b/configs/recipes/vision/qwen2_5_vl_3b/inference/infer.yaml index 46aeee2eb..c41dc42fd 100644 --- a/configs/recipes/vision/qwen2_5_vl_3b/inference/infer.yaml +++ b/configs/recipes/vision/qwen2_5_vl_3b/inference/infer.yaml @@ -1,14 +1,9 @@ # Qwen 2.5 VL 3B inference config. # -# Requirements: -# !Important! this model requires the latest (dev) version of transformers. -# Please read more at qwen2_5_vl_3b/README.md -# # Usage: # oumi infer -i -c configs/recipes/vision/qwen2_5_vl_3b/inference/infer.yaml \ # --image "tests/testdata/images/the_great_wave_off_kanagawa.jpg" # -# # See Also: # - Documentation: https://oumi.ai/docs/en/latest/user_guides/infer/infer.html # - Config class: oumi.core.configs.InferenceConfig diff --git a/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml b/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml index 1fc90dc1a..b80ef053b 100644 --- a/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml +++ b/configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml @@ -3,9 +3,6 @@ # Requirements: # - Run `pip install oumi[gpu]` # -# !Important! this model also requires the latest (dev) version of transformers. 
-# Please read more at qwen2_5_vl_3b/README.md -# # Usage: # oumi infer -i -c configs/recipes/vision/qwen2_5_vl_3b/inference/vllm_infer.yaml \ # --image "tests/testdata/images/the_great_wave_off_kanagawa.jpg" diff --git a/configs/recipes/vision/qwen2_5_vl_3b/sft/full/train.yaml b/configs/recipes/vision/qwen2_5_vl_3b/sft/full/train.yaml index a971f894d..5b2ff537a 100644 --- a/configs/recipes/vision/qwen2_5_vl_3b/sft/full/train.yaml +++ b/configs/recipes/vision/qwen2_5_vl_3b/sft/full/train.yaml @@ -4,9 +4,6 @@ # - Log into WandB (`wandb login`) or disable `enable_wandb` # - (optional) If you want to use flash attention, run `pip install -U flash-attn --no-build-isolation` # -# !Important! this model requires the latest (dev) version of transformers. -# Please read more at qwen2_5_vl_3b/README.md -# # Usage: # oumi train -c configs/recipes/vision/qwen2_5_vl_3b/sft/full/train.yaml # diff --git a/configs/recipes/vision/qwen2_5_vl_3b/sft/lora/train.yaml b/configs/recipes/vision/qwen2_5_vl_3b/sft/lora/train.yaml index 51289924a..dd5ef8e10 100644 --- a/configs/recipes/vision/qwen2_5_vl_3b/sft/lora/train.yaml +++ b/configs/recipes/vision/qwen2_5_vl_3b/sft/lora/train.yaml @@ -4,9 +4,6 @@ # - Log into WandB (`wandb login`) or disable `enable_wandb` # - (optional) If you want to use flash attention, run `pip install -U flash-attn --no-build-isolation` # -# !Important! this model requires the latest (dev) version of transformers. -# Please read more at qwen2_5_vl_3b/README.md -# # Usage: # oumi train -c configs/recipes/vision/qwen2_5_vl_3b/sft/lora/train.yaml # diff --git a/pyproject.toml b/pyproject.toml index 87dc95013..ee3b3bf00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ dependencies = [ "tqdm", # Llama Vision attention is broken as late as 4.48.2 if gradient checkpointing is # enabled. See OPE-875 and https://github.com/huggingface/transformers/issues/36040. 
- "transformers>=4.49.0,<4.50", + "transformers>=4.51.0,<4.52", # >=0.14.0 is needed for GRPOTrainer. "trl>=0.15.0,<0.16", "typer", # Used by CLI diff --git a/tests/e2e/test_eval_e2e.py b/tests/e2e/test_eval_e2e.py index 001a0693f..f9828b1d0 100644 --- a/tests/e2e/test_eval_e2e.py +++ b/tests/e2e/test_eval_e2e.py @@ -298,6 +298,7 @@ def test_eval_multimodal_1gpu_24gb(test_config: EvalTestConfig, tmp_path: Path): / "eval.yaml" ), num_samples=20, + use_simple_oumi_evaluate_command=True, ), EvalTestConfig( test_name="eval_text_deepseek_r1_distill_llama70b_multi_gpu", diff --git a/tests/unit/builders/test_models.py b/tests/unit/builders/test_models.py index 1df2d92c9..03c2afcd1 100644 --- a/tests/unit/builders/test_models.py +++ b/tests/unit/builders/test_models.py @@ -134,7 +134,7 @@ def test_build_chat_template_removes_indentation_and_newlines(): ("llava-hf/llava-1.5-7b-hf", False, True), ("Salesforce/blip2-opt-2.7b", False, True), ("microsoft/Phi-3-vision-128k-instruct", True, True), - # ("HuggingFaceTB/SmolVLM-Instruct", False, True), # requires transformers>=4.46 + ("HuggingFaceTB/SmolVLM-Instruct", False, True), ], ) def test_is_image_text_llm( From a59c9d7887c8a5de8c4e9c3263b549dea8fe4986 Mon Sep 17 00:00:00 2001 From: Matthew Persons Date: Tue, 8 Apr 2025 10:31:37 -0700 Subject: [PATCH 05/15] Lazy load skypilot (#1622) --- src/oumi/launcher/clients/sky_client.py | 53 ++++++++++++++------ src/oumi/launcher/clouds/sky_cloud.py | 31 ++++++++---- src/oumi/launcher/clusters/sky_cluster.py | 8 +-- tests/unit/launcher/clouds/test_sky_cloud.py | 38 +++++++------- tests/unit/launcher/test_launcher.py | 23 +++++++++ 5 files changed, 108 insertions(+), 45 deletions(-) diff --git a/src/oumi/launcher/clients/sky_client.py b/src/oumi/launcher/clients/sky_client.py index f6edcfc3c..0ad76a21f 100644 --- a/src/oumi/launcher/clients/sky_client.py +++ b/src/oumi/launcher/clients/sky_client.py @@ -14,20 +14,23 @@ import os from enum import Enum -from typing import Any, Optional - 
-import sky -import sky.data -from sky.clouds import CloudImplementationFeatures +from typing import TYPE_CHECKING, Any, Optional from oumi.core.configs import JobConfig from oumi.core.launcher import JobStatus from oumi.utils.logging import logger from oumi.utils.str_utils import try_str_to_bool +if TYPE_CHECKING: + import sky + import sky.data + -def _get_sky_cloud_from_job(job: JobConfig) -> sky.clouds.Cloud: +def _get_sky_cloud_from_job(job: JobConfig) -> "sky.clouds.Cloud": """Returns the sky.Cloud object from the JobConfig.""" + # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + import sky + if job.resources.cloud == SkyClient.SupportedClouds.GCP.value: return sky.clouds.GCP() elif job.resources.cloud == SkyClient.SupportedClouds.RUNPOD.value: @@ -41,8 +44,11 @@ def _get_sky_cloud_from_job(job: JobConfig) -> sky.clouds.Cloud: raise ValueError(f"Unsupported cloud: {job.resources.cloud}") -def _get_sky_storage_mounts_from_job(job: JobConfig) -> dict[str, sky.data.Storage]: +def _get_sky_storage_mounts_from_job(job: JobConfig) -> dict[str, "sky.data.Storage"]: """Returns the sky.StorageMount objects from the JobConfig.""" + # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + import sky.data + sky_mounts = {} for k, v in job.storage_mounts.items(): storage_mount = sky.data.Storage( @@ -76,8 +82,11 @@ def _get_use_spot_vm_override() -> Optional[bool]: raise ValueError(f"{_ENV_VAR_NAME} has unsupported value: '{s}'.") -def _convert_job_to_task(job: JobConfig) -> sky.Task: +def _convert_job_to_task(job: JobConfig) -> "sky.Task": """Converts a JobConfig to a sky.Task.""" + # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + import sky + sky_cloud = _get_sky_cloud_from_job(job) use_spot_vm = _get_use_spot_vm_override() if use_spot_vm is None: @@ -123,6 +132,13 @@ class SupportedClouds(Enum): RUNPOD = "runpod" LAMBDA = "lambda" + def __init__(self): + """Initializes a new instance of the SkyClient class.""" + # Delay sky 
import: https://github.com/oumi-ai/oumi/issues/1605 + import sky + + self._sky_lib = sky + def launch( self, job: JobConfig, cluster_name: Optional[str] = None, **kwargs ) -> JobStatus: @@ -146,7 +162,10 @@ def launch( sky_resources = next(iter(sky_task.resources)) # This will raise an exception if the cloud does not support stopping. sky_cloud.check_features_are_supported( - sky_resources, requested_features={CloudImplementationFeatures.STOP} + sky_resources, + requested_features={ + self._sky_lib.clouds.CloudImplementationFeatures.STOP + }, ) autostop_kw = "idle_minutes_to_autostop" # Default to 60 minutes. @@ -165,7 +184,7 @@ def launch( "Will not set autostop." ) - job_id, resource_handle = sky.launch( + job_id, resource_handle = self._sky_lib.launch( sky_task, cluster_name=cluster_name, detach_run=True, @@ -188,7 +207,7 @@ def status(self) -> list[dict[str, Any]]: Returns: A list of dictionaries, each containing the status of a cluster. """ - return sky.status() + return self._sky_lib.status() def queue(self, cluster_name: str) -> list[dict]: """Gets the job queue of a cluster. @@ -199,7 +218,7 @@ def queue(self, cluster_name: str) -> list[dict]: Returns: A list of dictionaries, each containing the metadata of a cluster. """ - return sky.queue(cluster_name) + return self._sky_lib.queue(cluster_name) def cancel(self, cluster_name: str, job_id: str) -> None: """Gets the job queue of a cluster. @@ -208,7 +227,7 @@ def cancel(self, cluster_name: str, job_id: str) -> None: cluster_name: The name of the cluster to cancel the job on. job_id: The ID of the job to cancel. """ - sky.cancel(cluster_name, int(job_id)) + self._sky_lib.cancel(cluster_name, int(job_id)) def exec(self, job: JobConfig, cluster_name: str) -> str: """Executes the specified job on the target cluster. @@ -220,7 +239,9 @@ def exec(self, job: JobConfig, cluster_name: str) -> str: Returns: The ID of the job that was created. 
""" - job_id, _ = sky.exec(_convert_job_to_task(job), cluster_name, detach_run=True) + job_id, _ = self._sky_lib.exec( + _convert_job_to_task(job), cluster_name, detach_run=True + ) if job_id is None: raise RuntimeError("Failed to submit job.") return str(job_id) @@ -231,7 +252,7 @@ def stop(self, cluster_name: str) -> None: Args: cluster_name: The name of the cluster to stop. """ - sky.stop(cluster_name) + self._sky_lib.stop(cluster_name) def down(self, cluster_name: str) -> None: """Tears down the target cluster. @@ -239,4 +260,4 @@ def down(self, cluster_name: str) -> None: Args: cluster_name: The name of the cluster to tear down. """ - sky.down(cluster_name) + self._sky_lib.down(cluster_name) diff --git a/src/oumi/launcher/clouds/sky_cloud.py b/src/oumi/launcher/clouds/sky_cloud.py index 4cc54bcee..efc14a74f 100644 --- a/src/oumi/launcher/clouds/sky_cloud.py +++ b/src/oumi/launcher/clouds/sky_cloud.py @@ -14,8 +14,6 @@ from typing import Optional, TypeVar -import sky - from oumi.core.configs import JobConfig from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus from oumi.core.registry import register_cloud_builder @@ -28,13 +26,25 @@ class SkyCloud(BaseCloud): """A resource pool capable of creating clusters using Sky Pilot.""" - def __init__(self, cloud_name: str, client: SkyClient): + @property + def _client(self) -> SkyClient: + """Returns the SkyClient instance.""" + # Instantiating a SkyClient imports sky. 
+ # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + if not self._sky_client: + self._sky_client = SkyClient() + return self._sky_client + + def __init__(self, cloud_name: str): """Initializes a new instance of the SkyCloud class.""" self._cloud_name = cloud_name - self._client = client + self._sky_client: Optional[SkyClient] = None def _get_clusters_by_class(self, cloud_class: type[T]) -> list[BaseCluster]: """Gets the appropriate clusters of type T.""" + # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + import sky + return [ SkyCluster(cluster["name"], self._client) for cluster in self._client.status() @@ -62,6 +72,9 @@ def get_cluster(self, name) -> Optional[BaseCluster]: def list_clusters(self) -> list[BaseCluster]: """Lists the active clusters on this cloud.""" + # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + import sky + if self._cloud_name == SkyClient.SupportedClouds.GCP.value: return self._get_clusters_by_class(sky.clouds.GCP) elif self._cloud_name == SkyClient.SupportedClouds.RUNPOD.value: @@ -78,28 +91,28 @@ def list_clusters(self) -> list[BaseCluster]: @register_cloud_builder("runpod") def runpod_cloud_builder() -> SkyCloud: """Builds a SkyCloud instance for runpod.""" - return SkyCloud(SkyClient.SupportedClouds.RUNPOD.value, SkyClient()) + return SkyCloud(SkyClient.SupportedClouds.RUNPOD.value) @register_cloud_builder("gcp") def gcp_cloud_builder() -> SkyCloud: """Builds a SkyCloud instance for Google Cloud Platform.""" - return SkyCloud(SkyClient.SupportedClouds.GCP.value, SkyClient()) + return SkyCloud(SkyClient.SupportedClouds.GCP.value) @register_cloud_builder("lambda") def lambda_cloud_builder() -> SkyCloud: """Builds a SkyCloud instance for Lambda.""" - return SkyCloud(SkyClient.SupportedClouds.LAMBDA.value, SkyClient()) + return SkyCloud(SkyClient.SupportedClouds.LAMBDA.value) @register_cloud_builder("aws") def aws_cloud_builder() -> SkyCloud: """Builds a SkyCloud instance for AWS.""" - return 
SkyCloud(SkyClient.SupportedClouds.AWS.value, SkyClient()) + return SkyCloud(SkyClient.SupportedClouds.AWS.value) @register_cloud_builder("azure") def azure_cloud_builder() -> SkyCloud: """Builds a SkyCloud instance for Azure.""" - return SkyCloud(SkyClient.SupportedClouds.AZURE.value, SkyClient()) + return SkyCloud(SkyClient.SupportedClouds.AZURE.value) diff --git a/src/oumi/launcher/clusters/sky_cluster.py b/src/oumi/launcher/clusters/sky_cluster.py index 6206d08c9..f2b5e88f4 100644 --- a/src/oumi/launcher/clusters/sky_cluster.py +++ b/src/oumi/launcher/clusters/sky_cluster.py @@ -14,8 +14,6 @@ from typing import Any, Optional -import sky.exceptions - from oumi.core.configs import JobConfig from oumi.core.launcher import BaseCluster, JobStatus from oumi.launcher.clients.sky_client import SkyClient @@ -26,6 +24,10 @@ class SkyCluster(BaseCluster): def __init__(self, name: str, client: SkyClient) -> None: """Initializes a new instance of the SkyCluster class.""" + # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605 + import sky.exceptions + + self._sky_exceptions = sky.exceptions self._name = name self._client = client @@ -78,7 +80,7 @@ def get_jobs(self) -> list[JobStatus]: self._convert_sky_job_to_status(job) for job in self._client.queue(self.name()) ] - except sky.exceptions.ClusterNotUpError: + except self._sky_exceptions.ClusterNotUpError: return [] def cancel_job(self, job_id: str) -> JobStatus: diff --git a/tests/unit/launcher/clouds/test_sky_cloud.py b/tests/unit/launcher/clouds/test_sky_cloud.py index 041ec2f5e..0d7a1a111 100644 --- a/tests/unit/launcher/clouds/test_sky_cloud.py +++ b/tests/unit/launcher/clouds/test_sky_cloud.py @@ -16,7 +16,11 @@ # @pytest.fixture def mock_sky_client(): - yield Mock(spec=SkyClient) + with patch("oumi.launcher.clouds.sky_cloud.SkyClient") as client: + client.SupportedClouds = SkyClient.SupportedClouds + client_instance = Mock(spec=SkyClient) + client.return_value = client_instance + yield client_instance 
@pytest.fixture @@ -122,7 +126,7 @@ def test_sky_cloud_up_cluster(mock_sky_client, mock_sky_cluster): }, ] mock_sky_client.launch.return_value = expected_job_status - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") job_status = cloud.up_cluster(_get_default_job("gcp"), "new_cluster_name") mock_sky_client.launch.assert_called_once_with( _get_default_job("gcp"), "new_cluster_name" @@ -183,7 +187,7 @@ def test_sky_cloud_up_cluster_kwargs(mock_sky_client, mock_sky_cluster): }, ] mock_sky_client.launch.return_value = expected_job_status - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") job_status = cloud.up_cluster( _get_default_job("gcp"), "new_cluster_name", custom_kwarg=1 ) @@ -246,14 +250,14 @@ def test_sky_cloud_up_cluster_no_name(mock_sky_client, mock_sky_cluster): }, ] mock_sky_client.launch.return_value = expected_job_status - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") job_status = cloud.up_cluster(_get_default_job("gcp"), None) mock_sky_client.launch.assert_called_once_with(_get_default_job("gcp"), None) assert job_status == expected_job_status def test_sky_cloud_list_clusters_gcp(mock_sky_client): - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -301,7 +305,7 @@ def test_sky_cloud_list_clusters_gcp(mock_sky_client): def test_sky_cloud_list_clusters_runpod(mock_sky_client): - cloud = SkyCloud("runpod", mock_sky_client) + cloud = SkyCloud("runpod") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -339,7 +343,7 @@ def test_sky_cloud_list_clusters_runpod(mock_sky_client): def test_sky_cloud_list_clusters_lambda(mock_sky_client): - cloud = SkyCloud("lambda", mock_sky_client) + cloud = SkyCloud("lambda") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() 
mock_gcp_handler.launched_resources = Mock() @@ -377,7 +381,7 @@ def test_sky_cloud_list_clusters_lambda(mock_sky_client): def test_sky_cloud_list_clusters_lambda_no_cluster(mock_sky_client): - cloud = SkyCloud("lambda", mock_sky_client) + cloud = SkyCloud("lambda") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -406,7 +410,7 @@ def test_sky_cloud_list_clusters_lambda_no_cluster(mock_sky_client): def test_sky_cloud_list_clusters_lambda_multiple_cluster(mock_sky_client): - cloud = SkyCloud("lambda", mock_sky_client) + cloud = SkyCloud("lambda") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -458,7 +462,7 @@ def test_sky_cloud_list_clusters_lambda_multiple_cluster(mock_sky_client): def test_sky_cloud_list_clusters_invalid_cloud(mock_sky_client): - cloud = SkyCloud("fake_cloud", mock_sky_client) + cloud = SkyCloud("fake_cloud") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -506,7 +510,7 @@ def test_sky_cloud_list_clusters_invalid_cloud(mock_sky_client): def test_sky_cloud_get_cluster_gcp_success(mock_sky_client): - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -545,7 +549,7 @@ def test_sky_cloud_get_cluster_gcp_success(mock_sky_client): def test_sky_cloud_get_cluster_runpod_success(mock_sky_client): - cloud = SkyCloud("runpod", mock_sky_client) + cloud = SkyCloud("runpod") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -584,7 +588,7 @@ def test_sky_cloud_get_cluster_runpod_success(mock_sky_client): def test_sky_cloud_get_cluster_lambda_success(mock_sky_client): - cloud = SkyCloud("lambda", mock_sky_client) + cloud = SkyCloud("lambda") 
mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -623,7 +627,7 @@ def test_sky_cloud_get_cluster_lambda_success(mock_sky_client): def test_sky_cloud_get_cluster_aws_success(mock_sky_client): - cloud = SkyCloud("aws", mock_sky_client) + cloud = SkyCloud("aws") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -662,7 +666,7 @@ def test_sky_cloud_get_cluster_aws_success(mock_sky_client): def test_sky_cloud_get_cluster_azure_success(mock_sky_client): - cloud = SkyCloud("azure", mock_sky_client) + cloud = SkyCloud("azure") mock_gcp_cluster = Mock(spec=sky.clouds.GCP) mock_gcp_handler = Mock() mock_gcp_handler.launched_resources = Mock() @@ -701,7 +705,7 @@ def test_sky_cloud_get_cluster_azure_success(mock_sky_client): def test_sky_cloud_get_cluster_failure_wrong_cloud(mock_sky_client): - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") mock_runpod_cluster = Mock(spec=sky.clouds.RunPod) mock_runpod_handler = Mock() @@ -731,7 +735,7 @@ def test_sky_cloud_get_cluster_failure_wrong_cloud(mock_sky_client): def test_sky_cloud_get_cluster_failure_empty(mock_sky_client): - cloud = SkyCloud("gcp", mock_sky_client) + cloud = SkyCloud("gcp") mock_sky_client.status.return_value = [] cluster = cloud.get_cluster("gcp_cluster") mock_sky_client.status.assert_called_once() diff --git a/tests/unit/launcher/test_launcher.py b/tests/unit/launcher/test_launcher.py index dad839473..b7d525838 100644 --- a/tests/unit/launcher/test_launcher.py +++ b/tests/unit/launcher/test_launcher.py @@ -1,3 +1,4 @@ +from multiprocessing import Process, set_start_method from unittest.mock import Mock, patch import pytest @@ -1165,3 +1166,25 @@ def test_launcher_export_methods(mock_registry): assert LAUNCHER.stop == stop assert LAUNCHER.get_cloud == get_cloud assert LAUNCHER.which_clouds == which_clouds + + +def _verify_no_extra_import(extra_module: str): 
+ """Verifies that extra modules are not imported.""" + import sys + + import oumi.launcher # noqa + + assert extra_module not in sys.modules, f"{extra_module} was imported." + + +def test_launcher_no_sky_dependency(): + # Ensure that sky is lazy loaded so it doesn't cause DB contention in multinode + # jobs: https://github.com/oumi-ai/oumi/issues/1605 + + set_start_method("spawn", force=True) + process = Process(target=_verify_no_extra_import, args=["sky"]) + process.start() + process.join() + assert ( + process.exitcode == 0 + ), "Sky was imported as part of the launcher module. This is a regression." From 124b2fd2cb92727ff4fbe10fba36557bf892d4db Mon Sep 17 00:00:00 2001 From: Yushi Homma Date: Wed, 9 Apr 2025 16:59:32 -0700 Subject: [PATCH 06/15] Add additional_model_kwargs and additional_trainer_kwargs to train function (#1624) --- src/oumi/__init__.py | 12 ++++++++++-- src/oumi/train.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/oumi/__init__.py b/src/oumi/__init__.py index 2d3cec2cf..9150cd6d7 100644 --- a/src/oumi/__init__.py +++ b/src/oumi/__init__.py @@ -240,11 +240,19 @@ def judge_dataset(config: JudgeConfig, dataset: BaseSftDataset) -> list[dict[str return oumi.judge.judge_dataset(config, dataset) -def train(config: TrainingConfig, **kwargs) -> None: +def train( + config: TrainingConfig, + additional_model_kwargs: dict[str, Any] | None = None, + additional_trainer_kwargs: dict[str, Any] | None = None, +) -> None: """Trains a model using the provided configuration.""" import oumi.train - return oumi.train.train(config, *kwargs) + return oumi.train.train( + config, + additional_model_kwargs=additional_model_kwargs, + additional_trainer_kwargs=additional_trainer_kwargs, + ) __all__ = [ diff --git a/src/oumi/train.py b/src/oumi/train.py index a09ade722..4c34fffe6 100644 --- a/src/oumi/train.py +++ b/src/oumi/train.py @@ -189,6 +189,7 @@ def _create_optional_training_kwargs( metrics_function: Optional[Callable], 
reward_functions: list[Callable], collator: Optional[Callable], + additional_trainer_kwargs: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: kwargs: dict[str, Any] = {"processing_class": tokenizer} if trainer_type == TrainerType.OUMI: @@ -204,10 +205,15 @@ def _create_optional_training_kwargs( if collator: raise ValueError(f"collator isn't supported for {trainer_type}") kwargs["reward_funcs"] = reward_functions + kwargs.update(additional_trainer_kwargs or {}) return kwargs -def train(config: TrainingConfig, **kwargs) -> None: +def train( + config: TrainingConfig, + additional_model_kwargs: Optional[dict[str, Any]] = None, + additional_trainer_kwargs: Optional[dict[str, Any]] = None, +) -> None: """Trains a model using the provided configuration.""" _START_TIME = time.time() @@ -275,7 +281,7 @@ def train(config: TrainingConfig, **kwargs) -> None: model = build_model( model_params=config.model, peft_params=config.peft if use_peft else None, - *kwargs, + **(additional_model_kwargs or {}), ) if use_peft: @@ -354,6 +360,7 @@ def train(config: TrainingConfig, **kwargs) -> None: metrics_function, reward_functions, collator, + additional_trainer_kwargs=additional_trainer_kwargs, ) # Reclaim memory before training starts. 
From 2b62dbb247d11dedeaf3b72a21e64f50722188ef Mon Sep 17 00:00:00 2001 From: Joe W <761337+jrwana@users.noreply.github.com> Date: Thu, 10 Apr 2025 07:37:29 +0700 Subject: [PATCH 07/15] Added 3 Pixmo vision-language datasets (#1523) Co-authored-by: castielle <761337+castielle@users.noreply.github.com> Co-authored-by: nikg4 Co-authored-by: William Zeng <10782997+wizeng23@users.noreply.github.com> --- src/oumi/datasets/__init__.py | 8 +++ src/oumi/datasets/vision_language/__init__.py | 8 +++ .../pixmo_ask_model_anything.py | 53 ++++++++++++++ .../datasets/vision_language/pixmo_cap.py | 57 +++++++++++++++ .../datasets/vision_language/pixmo_cap_qa.py | 63 ++++++++++++++++ .../test_sft_vision_datasets_load_datasets.py | 27 +++++++ tests/unit/datasets/test_pixmo.py | 71 +++++++++++++++++++ 7 files changed, 287 insertions(+) create mode 100644 src/oumi/datasets/vision_language/pixmo_ask_model_anything.py create mode 100644 src/oumi/datasets/vision_language/pixmo_cap.py create mode 100644 src/oumi/datasets/vision_language/pixmo_cap_qa.py create mode 100644 tests/unit/datasets/test_pixmo.py diff --git a/src/oumi/datasets/__init__.py b/src/oumi/datasets/__init__.py index e001ac948..089e0ec82 100644 --- a/src/oumi/datasets/__init__.py +++ b/src/oumi/datasets/__init__.py @@ -78,6 +78,11 @@ from oumi.datasets.vision_language.llava_instruct_mix_vsft import ( LlavaInstructMixVsftDataset, ) +from oumi.datasets.vision_language.pixmo_ask_model_anything import ( + PixmoAskModelAnythingDataset, +) +from oumi.datasets.vision_language.pixmo_cap import PixmoCapDataset +from oumi.datasets.vision_language.pixmo_cap_qa import PixmoCapQADataset from oumi.datasets.vision_language.vision_jsonlines import VLJsonlinesDataset __all__ = [ @@ -105,6 +110,9 @@ "OpenO1SFTDataset", "OrpoDpoMix40kDataset", "PileV1Dataset", + "PixmoAskModelAnythingDataset", + "PixmoCapDataset", + "PixmoCapQADataset", "PromptResponseDataset", "RedPajamaDataV1Dataset", "RedPajamaDataV2Dataset", diff --git 
a/src/oumi/datasets/vision_language/__init__.py b/src/oumi/datasets/vision_language/__init__.py index e7f435f97..55b832bae 100644 --- a/src/oumi/datasets/vision_language/__init__.py +++ b/src/oumi/datasets/vision_language/__init__.py @@ -21,6 +21,11 @@ LlavaInstructMixVsftDataset, ) from oumi.datasets.vision_language.mnist_sft import MnistSftDataset +from oumi.datasets.vision_language.pixmo_ask_model_anything import ( + PixmoAskModelAnythingDataset, +) +from oumi.datasets.vision_language.pixmo_cap import PixmoCapDataset +from oumi.datasets.vision_language.pixmo_cap_qa import PixmoCapQADataset from oumi.datasets.vision_language.the_cauldron import TheCauldronDataset from oumi.datasets.vision_language.vision_jsonlines import VLJsonlinesDataset from oumi.datasets.vision_language.vqav2_small import Vqav2SmallDataset @@ -31,6 +36,9 @@ "Flickr30kDataset", "LlavaInstructMixVsftDataset", "MnistSftDataset", + "PixmoAskModelAnythingDataset", + "PixmoCapDataset", + "PixmoCapQADataset", "VLJsonlinesDataset", "Vqav2SmallDataset", "TheCauldronDataset", diff --git a/src/oumi/datasets/vision_language/pixmo_ask_model_anything.py b/src/oumi/datasets/vision_language/pixmo_ask_model_anything.py new file mode 100644 index 000000000..4e18bb231 --- /dev/null +++ b/src/oumi/datasets/vision_language/pixmo_ask_model_anything.py @@ -0,0 +1,53 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing_extensions import override # noqa: I001 + +from oumi.core.datasets import VisionLanguageSftDataset +from oumi.core.registry import register_dataset +from oumi.core.types.conversation import ( + ContentItem, + Conversation, + Message, + Role, + Type, +) + + +@register_dataset("allenai/pixmo-ask-model-anything") +class PixmoAskModelAnythingDataset(VisionLanguageSftDataset): + """Dataset class for the `allenai/pixmo-ask-model-anything` dataset. + + The dataset is affected by some image URLs having a 404 issue. + """ + + default_dataset = "allenai/pixmo-ask-model-anything" + + @override + def transform_conversation(self, example: dict) -> Conversation: + """Transform the example into a Conversation object.""" + conversation = Conversation( + messages=[ + Message( + role=Role.USER, + content=[ + ContentItem(type=Type.IMAGE_URL, content=example["image_url"]), + ContentItem(type=Type.TEXT, content=example["question"]), + ], + ), + Message(role=Role.ASSISTANT, content=example["answer"]), + ] + ) + + return conversation diff --git a/src/oumi/datasets/vision_language/pixmo_cap.py b/src/oumi/datasets/vision_language/pixmo_cap.py new file mode 100644 index 000000000..e7c4df461 --- /dev/null +++ b/src/oumi/datasets/vision_language/pixmo_cap.py @@ -0,0 +1,57 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from typing_extensions import override # noqa: I001 + +from oumi.core.datasets import VisionLanguageSftDataset +from oumi.core.registry import register_dataset +from oumi.core.types.conversation import ( + ContentItem, + Conversation, + Message, + Role, + Type, +) + + +@register_dataset("allenai/pixmo-cap") +class PixmoCapDataset(VisionLanguageSftDataset): + """Dataset class for the `allenai/pixmo-cap` dataset. + + The dataset is affected by some image URLs having a 404 issue. + """ + + default_dataset = "allenai/pixmo-cap" + + @override + def transform_conversation(self, example: dict) -> Conversation: + """Transform the example into a Conversation object. + + A "transcripts" column is also available but not used yet. + """ + input_text = "Describe this image:" + + messages: list[Message] = [] + messages.append( + Message( + role=Role.USER, + content=[ + ContentItem(type=Type.IMAGE_URL, content=example["image_url"]), + ContentItem(type=Type.TEXT, content=input_text), + ], + ) + ) + messages.append(Message(role=Role.ASSISTANT, content=example["caption"])) + + return Conversation(messages=messages) diff --git a/src/oumi/datasets/vision_language/pixmo_cap_qa.py b/src/oumi/datasets/vision_language/pixmo_cap_qa.py new file mode 100644 index 000000000..e019d1708 --- /dev/null +++ b/src/oumi/datasets/vision_language/pixmo_cap_qa.py @@ -0,0 +1,63 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing_extensions import override # noqa: I001 + +from oumi.core.datasets import VisionLanguageSftDataset +from oumi.core.registry import register_dataset +from oumi.core.types.conversation import ( + ContentItem, + Conversation, + Message, + Role, + Type, +) + + +@register_dataset("allenai/pixmo-cap-qa") +class PixmoCapQADataset(VisionLanguageSftDataset): + """Dataset class for the `allenai/pixmo-cap-qa` dataset. + + The dataset is affected by some image URLs having a 404 issue. + """ + + default_dataset = "allenai/pixmo-cap-qa" + + @override + def transform_conversation(self, example: dict) -> Conversation: + """Transform the example into a Conversation object. + + Sample "question": "[USER] Can you come up with a joke? [ASSISTANT]" + It starts with a [USER] and ends with an [ASSISTANT] role tag. + The Assistant response appears in the "answer" field. + """ + messages: list[Message] = [] + messages.append( + Message( + role=Role.USER, + content=[ + ContentItem(type=Type.IMAGE_URL, content=example["image_url"]), + ], + ) + ) + questions = example["question"].strip().split("\n") + for question in questions: + if question.startswith("[USER]"): + question = question[len("[USER]") :].strip() + messages.append(Message(role=Role.USER, content=question)) + elif question.startswith("[ASSISTANT]"): + answer = question[len("[ASSISTANT]") :].strip() + messages.append(Message(role=Role.ASSISTANT, content=answer)) + messages.append(Message(role=Role.ASSISTANT, content=example["answer"])) + return Conversation(messages=messages) diff --git a/tests/integration/datasets/test_sft_vision_datasets_load_datasets.py b/tests/integration/datasets/test_sft_vision_datasets_load_datasets.py index c6b791df0..aef19eea7 100644 --- a/tests/integration/datasets/test_sft_vision_datasets_load_datasets.py +++ b/tests/integration/datasets/test_sft_vision_datasets_load_datasets.py @@ -106,6 +106,33 @@ def _get_all_sft_vision_dataset_infos() -> list[LoadDatasetInfo]: max_rows=64, 
expected_rows=64, ), + LoadDatasetInfo( + dataset_name="allenai/pixmo-ask-model-anything", + model_name=_DEFAULT_MODEL_NAME, + dataset_split="train[10:20]", # 404 error for some image URLs + chat_template=_DEFAULT_CHAT_TEMPLATE, + trust_remote_code=True, + max_rows=64, + expected_rows=None, + ), + LoadDatasetInfo( + dataset_name="allenai/pixmo-cap", + model_name=_DEFAULT_MODEL_NAME, + dataset_split="train[50:51]", # 429 error for some image URLs + chat_template=_DEFAULT_CHAT_TEMPLATE, + trust_remote_code=True, + max_rows=64, + expected_rows=None, + ), + LoadDatasetInfo( + dataset_name="allenai/pixmo-cap-qa", + model_name=_DEFAULT_MODEL_NAME, + dataset_split="train[10:20]", # 404 error for some image URLs + chat_template=_DEFAULT_CHAT_TEMPLATE, + trust_remote_code=True, + max_rows=64, + expected_rows=None, + ), ] all_excluded_dataset_names_normalized = set( diff --git a/tests/unit/datasets/test_pixmo.py b/tests/unit/datasets/test_pixmo.py new file mode 100644 index 000000000..912b85ea7 --- /dev/null +++ b/tests/unit/datasets/test_pixmo.py @@ -0,0 +1,71 @@ +from unittest import mock + +import pytest + +from oumi.core.types.conversation import Conversation, Role +from oumi.datasets.vision_language import ( + PixmoAskModelAnythingDataset, + PixmoCapDataset, + PixmoCapQADataset, +) + + +@pytest.fixture +def sample_pixmo_ask_model_anything_entry(): + return { + "image_url": "http://oumi.ai/test.png", + "image_sha256": "1234567890", + "question": "What type of machine is this?", + "answer": "This is a vintage-style popcorn cart.", + } + + +@pytest.fixture +def sample_pixmo_cap_entry(): + return { + "image_url": "http://oumi.ai/test.png", + "caption": "This photograph depicts a striking black bird", + "transcripts": ["a", "b", "c"], + } + + +@pytest.fixture +def sample_pixmo_cap_qa_entry(): + return { + "image_url": "http://oumi.ai/test.png", + "question": "[USER] Color? 
[ASSISTANT] Blue [USER] Time?[ASSISTANT]", + "answer": "Noon", + } + + +def test_pixmo_ask_model_anything_dataset(sample_pixmo_ask_model_anything_entry): + with mock.patch.object( + PixmoAskModelAnythingDataset, "__init__", return_value=None + ) as _: + dataset = PixmoAskModelAnythingDataset() + conversation = dataset.transform_conversation(sample_pixmo_ask_model_anything_entry) + assert isinstance(conversation, Conversation) + assert len(conversation.messages) == 2 + assert conversation.messages[0].role == Role.USER + assert conversation.messages[1].role == Role.ASSISTANT + + +def test_pixmo_cap_dataset(sample_pixmo_cap_entry): + with mock.patch.object(PixmoCapDataset, "__init__", return_value=None) as _: + dataset = PixmoCapDataset() + conversation = dataset.transform_conversation(sample_pixmo_cap_entry) + assert isinstance(conversation, Conversation) + assert len(conversation.messages) >= 2 + assert conversation.messages[0].role == Role.USER + assert conversation.messages[1].role == Role.ASSISTANT + + +def test_pixmo_cap_qa_dataset(sample_pixmo_cap_qa_entry): + with mock.patch.object(PixmoCapQADataset, "__init__", return_value=None) as _: + dataset = PixmoCapQADataset() + conversation = dataset.transform_conversation(sample_pixmo_cap_qa_entry) + assert isinstance(conversation, Conversation) + assert len(conversation.messages) >= 3 + assert conversation.messages[0].role == Role.USER + assert conversation.messages[1].role == Role.USER + assert conversation.messages[2].role == Role.ASSISTANT From c1a36635c26dafc5c103dfe2f9f8e92dc92dc930 Mon Sep 17 00:00:00 2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:09:08 -0700 Subject: [PATCH 08/15] [GRPO] Add notebook to demonstrate GRPO & evaluation for letter counting (#1625) --- configs/examples/grpo_tldr/gcp_job.yaml | 1 - .../letter_counting/evaluation/eval.yaml | 3 +- .../letter_counting/evaluation/gcp_job.yaml | 8 +- .../letter_counting/grpo/gcp_job.yaml | 2 - 
.../examples/letter_counting/grpo/train.yaml | 12 +- .../halloumi/halloumi_eval_notebook.ipynb | 2 +- .../halloumi_inference_notebook.ipynb | 2 +- ...n a Letter Counting Model using GRPO.ipynb | 918 ++++++++++++++++++ src/oumi/datasets/grpo/letter_count.py | 19 +- .../grpo/rewards/count_letters_rewards.py | 23 +- .../evaluation/registry/count_letters_task.py | 2 +- .../rewards/test_count_letters_rewards.py | 20 +- 12 files changed, 972 insertions(+), 40 deletions(-) create mode 100644 notebooks/Oumi - Train a Letter Counting Model using GRPO.ipynb diff --git a/configs/examples/grpo_tldr/gcp_job.yaml b/configs/examples/grpo_tldr/gcp_job.yaml index 204ba1f50..4bc0d13d6 100644 --- a/configs/examples/grpo_tldr/gcp_job.yaml +++ b/configs/examples/grpo_tldr/gcp_job.yaml @@ -32,7 +32,6 @@ envs: setup: | set -e pip install uv && uv pip install oumi[gpu] - pip install -U flash-attn --no-build-isolation run: | set -e # Exit if any command failed. diff --git a/configs/examples/letter_counting/evaluation/eval.yaml b/configs/examples/letter_counting/evaluation/eval.yaml index 23d874ffc..a38274c33 100644 --- a/configs/examples/letter_counting/evaluation/eval.yaml +++ b/configs/examples/letter_counting/evaluation/eval.yaml @@ -3,6 +3,7 @@ # Requirements: # - Run `pip install oumi[gpu]` # - Log into HF: `huggingface-cli login` +# - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct # # Usage: # oumi evaluate -c oumi://configs/examples/letter_counting/evaluation/eval.yaml @@ -14,7 +15,7 @@ # - Other eval configs: configs/**/evaluation/ model: - model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + model_name: "meta-llama/Llama-3.2-3B-Instruct" model_max_length: 131072 torch_dtype_str: "bfloat16" attn_implementation: "sdpa" diff --git a/configs/examples/letter_counting/evaluation/gcp_job.yaml b/configs/examples/letter_counting/evaluation/gcp_job.yaml index 3b4772e28..6707e8689 100644 --- a/configs/examples/letter_counting/evaluation/gcp_job.yaml 
+++ b/configs/examples/letter_counting/evaluation/gcp_job.yaml @@ -3,6 +3,7 @@ # Requirements: # - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup # - Log into HF: `huggingface-cli login` +# - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct # # Usage: # oumi launch up -c oumi://configs/examples/letter_counting/evaluation/gcp_job.yaml --cluster letter-counting-eval @@ -30,13 +31,13 @@ envs: # NOTE: For SFT, update this to point to your model checkpoint. # NOTE: For LoRA, instead update this to point to your LoRA adapter. # The base model will be inferred automatically. - MODEL_CHECKPOINT_DIR: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + MODEL_CHECKPOINT_DIR: meta-llama/Llama-3.2-3B-Instruct WANDB_PROJECT: oumi-eval OUMI_RUN_NAME: letter-counting.eval setup: | set -e - pip install uv && uv pip install oumi[gpu,evaluation] + pip install uv && uv pip install oumi[gpu] run: | set -e # Exit if any command failed. @@ -50,8 +51,7 @@ run: | echo "Starting evaluation for ${MODEL_CHECKPOINT_DIR} ..." set -x - accelerate launch \ - -m oumi evaluate \ + oumi evaluate \ -c oumi://configs/examples/letter_counting/evaluation/eval.yaml \ --run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" \ --model.model_name "${MODEL_CHECKPOINT_DIR}" diff --git a/configs/examples/letter_counting/grpo/gcp_job.yaml b/configs/examples/letter_counting/grpo/gcp_job.yaml index 967071826..b07dc4b76 100644 --- a/configs/examples/letter_counting/grpo/gcp_job.yaml +++ b/configs/examples/letter_counting/grpo/gcp_job.yaml @@ -33,9 +33,7 @@ envs: setup: | set -e - # vLLM needed for vLLM-powered generation during GRPO training. pip install uv && uv pip install oumi[gpu] - pip install -U flash-attn --no-build-isolation run: | set -e # Exit if any command failed. 
diff --git a/configs/examples/letter_counting/grpo/train.yaml b/configs/examples/letter_counting/grpo/train.yaml index b8b5483c1..530797142 100644 --- a/configs/examples/letter_counting/grpo/train.yaml +++ b/configs/examples/letter_counting/grpo/train.yaml @@ -3,6 +3,7 @@ # Requirements: # - Log into WandB (`wandb login`) or disable `enable_wandb` # - Log into HF: `huggingface-cli login` +# - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct # # Usage: # oumi train -c oumi://configs/examples/letter_counting/grpo/train.yaml @@ -14,7 +15,7 @@ # - Other training configs: configs/**/pretraining/, configs/**/sft/, configs/**/dpo/ model: - model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + model_name: "meta-llama/Llama-3.2-3B-Instruct" model_max_length: 8192 torch_dtype_str: "bfloat16" attn_implementation: "sdpa" @@ -29,8 +30,11 @@ training: trainer_type: "TRL_GRPO" save_steps: 500 max_steps: 500 - per_device_train_batch_size: 3 + per_device_train_batch_size: 2 gradient_accumulation_steps: 1 + learning_rate: 5e-5 + lr_scheduler_type: "cosine" + warmup_steps: 20 reward_functions: ["count_letters"] @@ -38,11 +42,11 @@ training: gradient_checkpointing_kwargs: use_reentrant: False ddp_find_unused_parameters: False - optimizer: "adamw_torch_fused" + optimizer: "adafactor" compile: True # Set to False if `grpo.use_vllm` is enabled: https://github.com/vllm-project/vllm/issues/12783 grpo: - num_generations: 6 + num_generations: 4 use_vllm: False dataloader_num_workers: "auto" diff --git a/configs/projects/halloumi/halloumi_eval_notebook.ipynb b/configs/projects/halloumi/halloumi_eval_notebook.ipynb index e18d6429c..d01faa0c2 100644 --- a/configs/projects/halloumi/halloumi_eval_notebook.ipynb +++ b/configs/projects/halloumi/halloumi_eval_notebook.ipynb @@ -55,7 +55,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install git+https://github.com/oumi-ai/oumi.git" + "%pip install oumi" ] }, { diff --git 
a/configs/projects/halloumi/halloumi_inference_notebook.ipynb b/configs/projects/halloumi/halloumi_inference_notebook.ipynb index 44a3fd636..e21527130 100644 --- a/configs/projects/halloumi/halloumi_inference_notebook.ipynb +++ b/configs/projects/halloumi/halloumi_inference_notebook.ipynb @@ -51,7 +51,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install git+https://github.com/oumi-ai/oumi.git" + "%pip install oumi" ] }, { diff --git a/notebooks/Oumi - Train a Letter Counting Model using GRPO.ipynb b/notebooks/Oumi - Train a Letter Counting Model using GRPO.ipynb new file mode 100644 index 000000000..23d903ab0 --- /dev/null +++ b/notebooks/Oumi - Train a Letter Counting Model using GRPO.ipynb @@ -0,0 +1,918 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n", + "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n", + "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n", + "\"Open\n", + "
\n", + "\n", + "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n", + "\n", + "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n", + "\n", + "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n", + "\n", + "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train a Letter Counting Model using GRPO\n", + "\n", + "This notebook delves into a fun, popular question to ask LLMs: \"How Many R’s Are in the Word Strawberry?\". First, we will use a custom evaluation function to evaluate many popular models on the task of counting letters in words. Then, we will use Group Relative Policy Optimization (GRPO), a reinforcement learning algorithm, to train Llama 3.2 3B to improve its performance on this task.\n", + "\n", + "This notebook includes cell outputs, but some irrelevant outputs (ex. install lines, warnings) are modified/removed for readability." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "### Machine Requirements\n", + "\n", + "This notebook runs both model evaluation and GRPO training, which require 8GB and 40GB VRAM, respectively.\n", + "\n", + "❗**NOTICE:** If you're running this notebook on Colab using a T4 GPU, it's not possible to run training due to memory requirements. To run evaluation, some adjustments need to be made as vLLM doesn't support T4 GPUs. This will be explained in the evaluation section.\n", + "\n", + "If your local machine cannot run this notebook, you can instead run this notebook on a cloud platform. The following demonstrates how to open a VSCode instance backed by a GCP node with 4 A100 GPUs, from which the notebook can be run. It is possible to run this notebook on just 1 GPU, but you will need to make some adjustments to training parameters, which will be explained in the training section.\n", + "\n", + "```bash\n", + "# Run on your local machine\n", + "gcloud auth application-default login # Authenticate with GCP\n", + "make gcpcode ARGS=\"--resources.accelerators A100:4\"\n", + "```\n", + "\n", + "### Oumi Installation\n", + "\n", + "First, let's install Oumi and vLLM (part of the `gpu` optional dependencies). You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install git+https://github.com/oumi-ai/oumi.git\n", + "%pip install \"vllm>=0.7.3,<0.8.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Remote API Access\n", + "\n", + "As part of this notebook, you can evaluate frontier models from OpenAI, Google, Anthropic, and Meta on the letter counting task. If you want to evaluate any of these models, set the corresponding fields below."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\" # Set your OpenAI API key here.\n", + "os.environ[\"GEMINI_API_KEY\"] = \"\" # Set your Gemini API key here.\n", + "os.environ[\"ANTHROPIC_API_KEY\"] = \"\" # Set your Anthropic API key here.\n", + "\n", + "# Set your GCP project id and region, if you want to query Llama 3.1 405B in Vertex.\n", + "REGION = \"\" # Set your GCP region here.\n", + "PROJECT_ID = \"\" # Set your GCP project id here." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tutorial Directory\n", + "\n", + "Finally, we'll set up a directory to use for this tutorial, and some environment variables." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "tutorial_dir = \"letter_counting_tutorial\"\n", + "\n", + "Path(tutorial_dir).mkdir(parents=True, exist_ok=True)\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\" # Disable warnings from HF.\n", + "\n", + "# This is needed for vLLM to use multiple GPUs in a notebook.\n", + "# If you're not running in a notebook, you can ignore this.\n", + "os.environ[\"VLLM_WORKER_MULTIPROC_METHOD\"] = \"spawn\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset\n", + "\n", + "The dataset we'll use for this notebook is `oumi-ai/oumi-letter-count`, which can be found on [HF Datasets](https://huggingface.co/datasets/oumi-ai/oumi-letter-count). Its prompts ask to count the letters in various English words, with metadata in each example containing the correct count. We use the `train` split for training and the `test` split for evaluation. We'll use an Oumi dataset class, `LetterCountGrpoDataset`, to load and preprocess the HF Dataset. 
The following code displays an example prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-04-10 09:36:55,268][oumi][rank0][pid:9994][MainThread][INFO]][base_map_dataset.py:91] Creating map dataset (type: LetterCountGrpoDataset)... dataset_name: 'oumi-ai/oumi-letter-count'\n", + "[2025-04-10 09:37:00,027][oumi][rank0][pid:9994][MainThread][INFO]][base_map_dataset.py:487] Dataset Info:\n", + "\tSplit: validation\n", + "\tVersion: 0.0.0\n", + "\tDataset size: 22894322\n", + "\tDownload size: 5697295\n", + "\tSize: 28591617 bytes\n", + "\tRows: 10000\n", + "\tColumns: ['conversation_id', 'messages', 'metadata']\n", + "[2025-04-10 09:37:00,248][oumi][rank0][pid:9994][MainThread][INFO]][base_map_dataset.py:426] Loaded DataFrame with shape: (10000, 3). Columns:\n", + "conversation_id object\n", + "messages object\n", + "metadata object\n", + "dtype: object\n", + "--------------------------------------------------------------------------------\n", + "Sample:\n", + "{'conversation_id': 'oumi_letter_count_0',\n", + " 'messages': [{'content': \"Could you determine the count of 'l's in \"\n", + " \"'substantial'?\",\n", + " 'role': 'user'},\n", + " {'content': 'Your final answer should be written as digits and '\n", + " 'formatted as \"\\\\boxed{your_answer}\". 
For example, '\n", + " 'if the answer is 42, make sure to output '\n", + " '\"\\\\boxed{42}\".',\n", + " 'role': 'system'}],\n", + " 'metadata': {'letter': 'l',\n", + " 'letter_count_integer': 1,\n", + " 'letter_count_string': 'one',\n", + " 'unformatted_prompt': 'Could you determine the count of '\n", + " '{letter}s in {word}?',\n", + " 'word': 'substantial'}}\n" + ] + } + ], + "source": [ + "from pprint import pprint\n", + "\n", + "from oumi.datasets.grpo.letter_count import LetterCountGrpoDataset\n", + "\n", + "dataset = LetterCountGrpoDataset(split=\"validation\")\n", + "print(\"-\" * 80)\n", + "print(\"Sample:\")\n", + "pprint(dataset.conversation(0).to_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation\n", + "\n", + "First, we'll evaluate how various models perform on the letter counting task. We'll evaluate frontier models by calling their respective remote API, and Llama 3.2 3B by running local inference on it using vLLM.\n", + "\n", + "We've already defined a custom evaluation function in Oumi which runs inference on the above dataset, extracts the answer from the model response, and calculates various metrics such as accuracy. This function is defined at `src/oumi/evaluation/registry/count_letters_task.py` ([GitHub link](https://github.com/oumi-ai/oumi/blob/main/src/oumi/evaluation/registry/count_letters_task.py)), and we print its contents below for reference." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "@register_evaluation_function(\"count_letters\")\n", + "def count_letters(\n", + " task_params: EvaluationTaskParams,\n", + " inference_engine: BaseInferenceEngine,\n", + ") -> dict[str, Any]:\n", + " \"\"\"Custom evaluation function registered as `count_letters`.\"\"\"\n", + " dataset = LetterCountGrpoDataset(split=\"test\")\n", + " # TODO: OPE-1155: Add support for using Oumi dataset code to create the dataset.\n", + " # dataset = build_dataset(\"oumi-ai/oumi-letter-count\", tokenizer=None, sample_count=10) # noqa: E501\n", + " # dataset = build_dataset(\"oumi-ai/berrybench-v0.1.0\", tokenizer=None, sample_count=10) # noqa: E501\n", + " num_samples = task_params.num_samples\n", + " if num_samples is None:\n", + " num_samples = len(dataset)\n", + " input_conversations = [dataset.conversation(i) for i in range(num_samples)]\n", + " conversations = inference_engine.infer(input_conversations)\n", + " logger.info(f\"Finished inference on {len(conversations)} conversations!\")\n", + " if len(conversations) > 0:\n", + " logger.info(f\"Sample conversation: {conversations[0]}\")\n", + "\n", + " count = 0 # The number of examples with correct answers extracted.\n", + " total = 0 # All examples.\n", + " valid_count = 0 # The number of examples with valid answers extracted.\n", + " for i, conversation in enumerate(conversations):\n", + " total += 1\n", + " # Grab the model's response\n", + " response = conversation.last_message()\n", + " # Ignore cases where model didn't respond or it's a multimodal response.\n", + " # For now, we focus on text-only responses.\n", + " if not response or not isinstance(response.content, str):\n", + " continue\n", + " # Count the example as correct if the extracted prediction is correct.\n", + " prediction = _extract_prediction(response.content)\n", + " if prediction is None:\n", + " 
continue\n", + " valid_count += 1\n", + " if prediction == conversation.metadata[\"letter_count_integer\"]:\n", + " count += 1\n", + "\n", + " return {\n", + " # Accuracy across all examples.\n", + " \"accuracy\": count / total,\n", + " # Accuracy when only counting examples with properly extracted answers.\n", + " \"properly_extracted_accuracy\": count / valid_count,\n", + " \"num_samples\": num_samples,\n", + " # These three values sum up to num_samples.\n", + " \"num_correct_answers\": count,\n", + " \"num_incorrect_answers\": valid_count - count,\n", + " \"num_invalid_answers\": total - valid_count,\n", + " }\n", + "\n" + ] + } + ], + "source": [ + "import inspect\n", + "\n", + "from oumi.evaluation.registry.count_letters_task import count_letters\n", + "\n", + "print(inspect.getsource(count_letters))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the following section, you can select which models you want to evaluate. You can lower `NUM_SAMPLES` to reduce cost when calling remote APIs, with the downside of noisier results." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_SAMPLES = 100\n", + "# We set an environment variable to be used at the end of the notebook.\n", + "os.environ[\"NUM_SAMPLES\"] = str(NUM_SAMPLES)\n", + "\n", + "model_names = [\n", + " \"llama_3b\",\n", + " # Uncomment any models you wish to evaluate - you can evaluate multiple at once.\n", + " # \"gpt_4o\",\n", + " # \"gemini_pro\",\n", + " # \"llama_405b\",\n", + " # \"claude_sonnet\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "❗**NOTICE:** If running this notebook on Colab, delete the following line: `inference_engine: VLLM`" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing letter_counting_tutorial/llama_3b_eval.yaml\n" + ] + } + ], + "source": [ + "%%writefile $tutorial_dir/llama_3b_eval.yaml\n", + "\n", + "# We save this config as a YAML file as we'll use it again at the end of the notebook.\n", + "model:\n", + " model_name: \"meta-llama/Llama-3.2-3B-Instruct\"\n", + " model_max_length: 131072\n", + " torch_dtype_str: \"bfloat16\"\n", + " attn_implementation: \"sdpa\"\n", + " trust_remote_code: True\n", + "\n", + "inference_engine: VLLM\n", + "\n", + "generation:\n", + " max_new_tokens: 2048\n", + "\n", + "tasks:\n", + " - evaluation_backend: custom\n", + " task_name: count_letters\n", + "\n", + "output_dir: \"letter_counting_tutorial/evaluation/llama_3b\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# EvaluationConfig for various models.\n", + "# Note that Llama 3B uses the local VLLM inference engine, while the others use various\n", + "# remote engines.\n", + "\n", + "with open(f\"{tutorial_dir}/llama_3b_eval.yaml\") as f:\n", + " llama_3b_yaml = f.read()\n", + "\n", + "configs = {\n", + " \"llama_3b\": 
llama_3b_yaml,\n", + " \"gpt_4o\": \"\"\"\n", + " model:\n", + " model_name: \"gpt-4o\"\n", + "\n", + " inference_engine: OPENAI\n", + "\n", + " inference_remote_params:\n", + " api_key_env_varname: \"OPENAI_API_KEY\"\n", + " max_retries: 3\n", + " num_workers: 100\n", + " politeness_policy: 60\n", + " connection_timeout: 300\n", + "\n", + " generation:\n", + " max_new_tokens: 8192\n", + " temperature: 0.0\n", + "\n", + " tasks:\n", + " - evaluation_backend: custom\n", + " task_name: count_letters\n", + "\n", + " output_dir: \"letter_counting_tutorial/evaluation/gpt_4o\"\n", + " \"\"\",\n", + " \"gemini_pro\": \"\"\"\n", + " model:\n", + " model_name: \"gemini-2.5-pro-preview-03-25\"\n", + "\n", + " inference_engine: GOOGLE_GEMINI\n", + "\n", + " inference_remote_params:\n", + " api_key_env_varname: \"GEMINI_API_KEY\"\n", + " max_retries: 3\n", + " num_workers: 2\n", + " politeness_policy: 60\n", + " connection_timeout: 300\n", + "\n", + " generation:\n", + " max_new_tokens: 8192\n", + " temperature: 0.0\n", + "\n", + " tasks:\n", + " - evaluation_backend: custom\n", + " task_name: count_letters\n", + "\n", + " output_dir: \"letter_counting_tutorial/evaluation/gemini_pro\"\n", + " \"\"\",\n", + " \"llama_405b\": f\"\"\"\n", + " model:\n", + " model_name: \"meta/llama-3.1-405b-instruct-maas\"\n", + "\n", + " inference_engine: GOOGLE_VERTEX\n", + "\n", + " inference_remote_params:\n", + " api_url: \"https://{REGION}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/openapi/chat/completions\"\n", + " max_retries: 3\n", + " num_workers: 10\n", + " politeness_policy: 60\n", + " connection_timeout: 300\n", + "\n", + " generation:\n", + " max_new_tokens: 8192\n", + " temperature: 0.0\n", + "\n", + " tasks:\n", + " - evaluation_backend: custom\n", + " task_name: count_letters\n", + "\n", + " output_dir: \"letter_counting_tutorial/evaluation/llama_405b\"\n", + " \"\"\",\n", + " \"claude_sonnet\": \"\"\"\n", + " model:\n", + " model_name: 
\"claude-3-7-sonnet-latest\"\n", + "\n", + " inference_engine: ANTHROPIC\n", + "\n", + " inference_remote_params:\n", + " api_key_env_varname: \"ANTHROPIC_API_KEY\"\n", + " max_retries: 3\n", + " num_workers: 5\n", + " politeness_policy: 65\n", + " connection_timeout: 300\n", + "\n", + " generation:\n", + " max_new_tokens: 8192\n", + " temperature: 0.0\n", + "\n", + " tasks:\n", + " - evaluation_backend: custom\n", + " task_name: count_letters\n", + "\n", + " output_dir: \"letter_counting_tutorial/evaluation/claude_sonnet\"\n", + " \"\"\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed prompts: 100%|██████████| 100/100 [00:03<00:00, 26.70it/s, est. speed input: 2430.26 toks/s, output: 635.19 toks/s] \n", + "[2025-04-10 09:38:26,617][oumi][rank0][pid:9994][MainThread][INFO]][count_letters_task.py:53] Finished inference on 100 conversations!\n", + "[2025-04-10 09:38:26,618][oumi][rank0][pid:9994][MainThread][INFO]][count_letters_task.py:55] Sample conversation: conversation_id='oumi_letter_count_0' messages=[USER: Look through 'perivaginal' and count the 'n's., SYSTEM: Your final answer should be written as digits and formatted as \"\\boxed{your_answer}\". For example, if the answer is 42, make sure to output \"\\boxed{42}\"., ASSISTANT: There are 2 'n's in 'perivaginal'. 
\n", + "\n", + "\\boxed{2}] metadata={'letter': 'n', 'letter_count_integer': 1, 'letter_count_string': 'one', 'unformatted_prompt': 'Look through {word} and count the {letter}s.', 'word': 'perivaginal'}\n" + ] + } + ], + "source": [ + "# Run evaluation on all specified models.\n", + "\n", + "from oumi.core.configs import EvaluationConfig\n", + "from oumi.core.evaluation import Evaluator\n", + "\n", + "results = {}\n", + "\n", + "for model_name in model_names:\n", + " # Create the evaluation config from the YAML string.\n", + " config_yaml: str = configs[model_name]\n", + " config = EvaluationConfig.from_str(config_yaml)\n", + " config.tasks[0].num_samples = NUM_SAMPLES\n", + "\n", + " # Run the evaluation.\n", + " evaluator = Evaluator()\n", + " evaluator_out = evaluator.evaluate(config)\n", + "\n", + " # # Record the results.\n", + " results[model_name] = evaluator_out[0].get_results()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total samples: 100\n", + "--------------------------------------------------------------------------------\n", + "Model: llama_3b\n", + "Accuracy: 24.00%\n", + "Num correct, incorrect, invalid: 24, 69, 7\n" + ] + } + ], + "source": [ + "# Print results.\n", + "\n", + "print(f\"Total samples: {NUM_SAMPLES}\")\n", + "for model_name, result in results.items():\n", + " print(\"-\" * 80)\n", + " print(f\"Model: {model_name}\")\n", + " print(f\"Accuracy: {result['accuracy']:.2%}\")\n", + " correct = result[\"num_correct_answers\"]\n", + " incorrect = result[\"num_incorrect_answers\"]\n", + " invalid = result[\"num_invalid_answers\"]\n", + " print(f\"Num correct, incorrect, invalid: {correct}, {incorrect}, {invalid}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GRPO\n", + "\n", + "Now, we train Llama 3.2 3B on the task of counting letters using the GRPO algorithm implemented by [HuggingFace's 
`trl` library](https://huggingface.co/docs/trl/en/index).\n", + "\n", + "Note that we can calculate a concrete reward for this task by comparing the answer extracted by the model with the correct answer. In the reward function defined in `src/oumi/datasets/grpo/rewards/count_letters_rewards.py` ([GitHub link](https://github.com/oumi-ai/oumi/blob/main/src/oumi/datasets/grpo/rewards/count_letters_rewards.py)), we calculate the reward to be `-abs(predicted_count - target_count)`. We use simple heuristics to extract the predicted count. The following cell prints out the reward function code." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Copyright 2025 - Oumi\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import re\n", + "from typing import Any, Optional\n", + "\n", + "from oumi.core.registry import RegistryType, register\n", + "\n", + "\n", + "def _extract_prediction(response: str) -> Optional[int]:\n", + " r\"\"\"Returns the numeric answer extracted from `\\boxed{...}`, or None otherwise.\"\"\"\n", + " regex_result = re.findall(r\"\\\\boxed\\{([-+]?\\d+)\\}\", response)\n", + " if not regex_result or len(regex_result) != 1:\n", + " return None\n", + " number_str = regex_result[0]\n", + " # Except clause shouldn't trigger because the regex should only find ints.\n", + 
" try:\n", + " return int(number_str)\n", + " except ValueError:\n", + " return None\n", + "\n", + "\n", + "def compute_letter_count_reward(completion: str, target_count: int) -> float:\n", + " \"\"\"Computes the rewards for counting the letters in a string.\n", + "\n", + " The last group of consecutive digits in the completion is assumed to be the letter\n", + " count. We're also assuming it's counting the correct letter. The reward is the\n", + " negative of the absolute difference between the count and the target count, plus 0.1\n", + " if the answer was properly formatted.\n", + "\n", + " For example, for the string \"There are 2 'r's in strawberry\", and the target count\n", + " is 3, the reward is -1.\n", + "\n", + " Args:\n", + " completion: The completion string from the LLM.\n", + " target_count: The target count of letters.\n", + "\n", + " Returns:\n", + " The reward value, calculated as the negative of the absolute difference between\n", + " the count and the target count. The count is assumed to be the last group of\n", + " consecutive digits in the completion string.\n", + " \"\"\"\n", + " count = _extract_prediction(completion)\n", + " formatting_reward = 0.1 if count is not None else 0\n", + " if count is None:\n", + " count = 0\n", + " return -abs(count - target_count) + formatting_reward\n", + "\n", + "\n", + "@register(\"count_letters\", RegistryType.REWARD_FUNCTION)\n", + "def _count_letters(\n", + " completions: list[list[dict[str, Any]]],\n", + " letter_count: list[int],\n", + " **kwargs: dict[str, Any],\n", + ") -> list[float]:\n", + " \"\"\"Custom reward function for counting letters in a string.\n", + "\n", + " For more details on custom reward functions used in trl's GRPOTrainer, see:\n", + " https://huggingface.co/docs/trl/main/en/grpo_trainer#using-a-custom-reward-function.\n", + "\n", + " Args:\n", + " completions: The list of completions from the LLM.\n", + " letter_count: The list of target count of letters.\n", + " kwargs: Unused.\n", 
+ "\n", + " Returns:\n", + " The reward values for each completion, calculated as the negative of the\n", + " absolute difference between the count and the target count. The count is assumed\n", + " to be the last group of consecutive digits in the completion string.\n", + " \"\"\"\n", + " completions_strs = [c[0][\"content\"] for c in completions]\n", + " return [\n", + " compute_letter_count_reward(c, t)\n", + " for c, t in zip(completions_strs, letter_count)\n", + " ]\n" + ] + } + ], + "source": [ + "!cat ../src/oumi/datasets/grpo/rewards/count_letters_rewards.py" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 04-10 09:38:28 multiproc_worker_utils.py:141] Terminating local vLLM worker processes\n", + "\u001b[1;36m(VllmWorkerProcess pid=10493)\u001b[0;0m INFO 04-10 09:38:28 multiproc_worker_utils.py:253] Worker exiting\n", + "\u001b[1;36m(VllmWorkerProcess pid=10494)\u001b[0;0m INFO 04-10 09:38:28 multiproc_worker_utils.py:253] Worker exiting\n", + "\u001b[1;36m(VllmWorkerProcess pid=10492)\u001b[0;0m INFO 04-10 09:38:28 multiproc_worker_utils.py:253] Worker exiting\n" + ] + } + ], + "source": [ + "# Clean up to free-up GPU memory used for evaluation above\n", + "import gc\n", + "\n", + "import torch\n", + "\n", + "\n", + "def cleanup_memory():\n", + " \"\"\"Delete the evaluator and collect garbage.\"\"\"\n", + " global evaluator\n", + " if evaluator: # type: ignore\n", + " del evaluator\n", + " evaluator = None\n", + " for _ in range(3):\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " torch.cuda.synchronize()\n", + "\n", + "\n", + "cleanup_memory()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "❗**NOTICE:** Set `training.enable_wandb` to True if you want to log your training run to Weights and Biases. In addition, you must also log into WandB, ex. 
by running `wandb login`.\n", + "\n", + "❗**NOTICE:** We only train for 2 steps for demonstration purposes. You can increase `max_steps`, or replace it with `num_train_epochs` to set your desired number of epochs." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing letter_counting_tutorial/grpo_train.yaml\n" + ] + } + ], + "source": [ + "%%writefile $tutorial_dir/grpo_train.yaml\n", + "\n", + "model:\n", + " model_name: \"meta-llama/Llama-3.2-3B-Instruct\"\n", + " model_max_length: 8192\n", + " torch_dtype_str: \"bfloat16\"\n", + " attn_implementation: \"sdpa\"\n", + "\n", + "data:\n", + " train:\n", + " datasets:\n", + " - dataset_name: \"oumi-ai/oumi-letter-count\"\n", + " split: \"train\"\n", + "\n", + "training:\n", + " trainer_type: \"TRL_GRPO\"\n", + " save_steps: 500\n", + " max_steps: 2\n", + " per_device_train_batch_size: 2\n", + " gradient_accumulation_steps: 1\n", + " learning_rate: 5e-5\n", + " lr_scheduler_type: \"cosine\"\n", + " warmup_steps: 20\n", + "\n", + " reward_functions: [\"count_letters\"]\n", + "\n", + " ddp_find_unused_parameters: False\n", + " optimizer: \"adafactor\"\n", + " compile: True\n", + "\n", + " grpo:\n", + " num_generations: 4\n", + "\n", + " dataloader_num_workers: \"auto\"\n", + " dataloader_prefetch_factor: 32\n", + "\n", + " logging_steps: 10\n", + " output_dir: \"letter_counting_tutorial/llama_3b_grpo\"\n", + " # Set this to True if you want to log to Weights and Biases.\n", + " enable_wandb: False" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-04-10 09:38:32,770][oumi][rank0][pid:10890][MainThread][INFO]][distributed_run.py:276] Running the command: ['torchrun', '--nnodes=1', '--node-rank=0', '--nproc-per-node=4', '--master-addr=127.0.0.1', '--master-port=8007', '-m', 'oumi', 'train', 
'-c', 'letter_counting_tutorial/grpo_train.yaml']\n", + "\n", + "\u001b[32m ____ _ _ __ __ _____\u001b[0m\n", + "\u001b[32m / __ \\| | | | \\/ |_ _|\u001b[0m\n", + "\u001b[32m | | | | | | | \\ / | | |\u001b[0m\n", + "\u001b[32m | | | | | | | |\\/| | | |\u001b[0m\n", + "\u001b[32m | |__| | |__| | | | |_| |_\u001b[0m\n", + "\u001b[32m \\____/ \\____/|_| |_|_____|\u001b[0m\n", + "\u001b[2K\u001b[32m⠦\u001b[0m \u001b[32mLoading configuration...\u001b[0m\n", + "Model Parameters Summary:\n", + "🔢 Total parameters: 3,212,749,824\n", + "🔗 Embedding parameters: 394,002,432\n", + "🎯 Trainable parameters: 3,212,749,824\n", + "🔒 Frozen parameters: 0 (0.00%)\n", + "\n", + "[2025-04-10 09:38:47,784][oumi][rank0][pid:10894][MainThread][INFO]][base_map_dataset.py:91] Creating map dataset (type: LetterCountGrpoDataset)... dataset_name: 'oumi-ai/oumi-letter-count'\n", + "[2025-04-10 09:38:50,326][oumi][rank0][pid:10894][MainThread][INFO]][base_map_dataset.py:487] Dataset Info:\n", + "\tSplit: train\n", + "\tVersion: 0.0.0\n", + "\tDataset size: 22894322\n", + "\tDownload size: 5697295\n", + "\tSize: 28591617 bytes\n", + "\tRows: 100000\n", + "\tColumns: ['conversation_id', 'messages', 'metadata']\n", + "Generating train split: 8 examples [00:00, 2656.51 examples/s]\n", + "[2025-04-10 09:38:50,827][oumi][rank0][pid:10894][MainThread][INFO]][base_map_dataset.py:426] Loaded DataFrame with shape: (100000, 3). Columns:\n", + "conversation_id object\n", + "messages object\n", + "metadata object\n", + "dtype: object\n", + "[2025-04-10 09:38:50,837][oumi][rank0][pid:10894][MainThread][INFO]][base_map_dataset.py:312] LetterCountGrpoDataset: features=dict_keys(['prompt', 'letter_count'])\n", + "Generating train split: 100000 examples [00:08, 11203.44 examples/s]\n", + "[2025-04-10 09:39:12,603][oumi][rank0][pid:10894][MainThread][INFO]][base_map_dataset.py:376] Finished transforming dataset (LetterCountGrpoDataset)! Speed: 4594.28 examples/sec. Examples: 100000. Duration: 21.8 sec. 
Transform workers: 1.\n", + "[2025-04-10 09:39:14,799][oumi][rank0][pid:10894][MainThread][INFO]][train.py:419] Training init time: 31.849s\n", + "[2025-04-10 09:39:14,800][oumi][rank0][pid:10894][MainThread][INFO]][train.py:420] Starting training... (TrainerType.TRL_GRPO, transformers: 4.51.1)\n", + "{'train_runtime': 226.3313, 'train_samples_per_second': 0.071, 'train_steps_per_second': 0.009, 'train_loss': -0.5479733943939209, 'rewards/_count_letters': -1.0625, 'reward': -1.0625, 'reward_std': 0.9289332032203674, 'completion_length': 47.8125, 'kl': 0.0001125335693359375, 'epoch': 0.0}\n", + "100%|████████████████████████████████████████████| 2/2 [03:46<00:00, 113.16s/it]\n", + "[2025-04-10 09:43:01,562][oumi][rank0][pid:10894][MainThread][INFO]][train.py:427] Training is Complete.\n", + "[2025-04-10 09:43:01,562][oumi][rank0][pid:10894][MainThread][INFO]][device_utils.py:297] GPU Metrics After Training: GPU runtime info: None.\n", + "[2025-04-10 09:43:01,562][oumi][rank0][pid:10894][MainThread][INFO]][torch_utils.py:136] Peak GPU memory usage: 29.84 GB\n", + "[2025-04-10 09:43:01,562][oumi][rank0][pid:10894][MainThread][INFO]][train.py:434] Saving final state...\n", + "[2025-04-10 09:43:01,563][oumi][rank0][pid:10894][MainThread][INFO]][train.py:439] Saving final model...\n", + "[2025-04-10 09:43:14,924][oumi][rank0][pid:10894][MainThread][INFO]][hf_trainer.py:116] Model has been saved at letter_counting_tutorial/llama_3b_grpo\n", + "[2025-04-10 09:43:15,458][oumi][rank0][pid:10894][MainThread][INFO]][train.py:446] \n", + "\n", + "» We're always looking for feedback. What's one thing we can improve? https://oumi.ai/feedback\n", + "[2025-04-10 09:43:23,534][oumi][rank0][pid:10890][MainThread][INFO]][distributed_run.py:295] Successfully completed! (Rank: 0. 
Duration: 290.8 sec)\n" + ] + } + ], + "source": [ + "!oumi distributed torchrun -m oumi train -c $tutorial_dir/grpo_train.yaml" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluating our Trained Model\n", + "\n", + "Let's now evaluate our trained model to see if it improved on the letter counting task. Note that it may not improve much, since we trained it for a relatively short time.\n", + "\n", + "Below, we demonstrate an alternative method of running evaluation with the `oumi` CLI. We use the same Llama 3B evaluation config we used above, with the only change being pointing it at the model we just trained.\n", + "\n", + "First, we need to reset the notebook to clear variables from our previous vLLM run." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "%reset -f" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[32m ____ _ _ __ __ _____\u001b[0m\n", + "\u001b[32m / __ \\| | | | \\/ |_ _|\u001b[0m\n", + "\u001b[32m | | | | | | | \\ / | | |\u001b[0m\n", + "\u001b[32m | | | | | | | |\\/| | | |\u001b[0m\n", + "\u001b[32m | |__| | |__| | | | |_| |_\u001b[0m\n", + "\u001b[32m \\____/ \\____/|_| |_|_____|\u001b[0m\n", + "\u001b[2K\u001b[32m⠴\u001b[0m \u001b[32mLoading configuration...\u001b[0m0m\n", + "\u001b[2K\u001b[32m⠋\u001b[0m \u001b[32mRunning evaluation...\u001b[0m[2025-04-10 09:47:15,521][oumi][rank0][pid:16694][MainThread][INFO]][models.py:482] Using the model's built-in chat template for model 'letter_counting_tutorial/llama_3b_grpo'.\n", + "\u001b[2KINFO 04-10 09:47:15 __init__.py:207] Automatically detected platform cuda.\n", + "\u001b[2KINFO 04-10 09:47:33 model_runner.py:1110] Starting to load model \n", + "letter_counting_tutorial/llama_3b_grpo...\n", + "\u001b[32m⠸\u001b[0m \u001b[32mRunning 
evaluation...\u001b[0m\u001b[1;36m(VllmWorkerProcess pid=16718)\u001b[0;0m INFO 04-10 09:47:33 model_runner.py:1110] Starting to load model letter_counting_tutorial/llama_3b_grpo...\n", + "\u001b[32m⠴\u001b[0m \u001b[32mRunning evaluation...\u001b[0m\u001b[1;36m(VllmWorkerProcess pid=16719)\u001b[0;0m INFO 04-10 09:47:34 model_runner.py:1115] Loading model weights took 1.5341 GB\n", + "\u001b[1;36m(VllmWorkerProcess pid=16718)\u001b[0;0m INFO 04-10 09:47:34 model_runner.py:1115] Loading model weights took 1.5341 GB\n", + "\n", + "\u001b[2K\u001b[32m⠹\u001b[0m \u001b[32mRunning evaluation...\u001b[0m[2025-04-10 09:47:46,962][oumi][rank0][pid:16694][MainThread][INFO]][base_map_dataset.py:91] Creating map dataset (type: LetterCountGrpoDataset)... dataset_name: 'oumi-ai/oumi-letter-count'\n", + "\u001b[2K\u001b[32m⠋\u001b[0m \u001b[32mRunning evaluation...\u001b[0m[2025-04-10 09:47:49,180][oumi][rank0][pid:16694][MainThread][INFO]][base_map_dataset.py:487] Dataset Info:\n", + "\tSplit: test\n", + "\tVersion: 0.0.0\n", + "\tDataset size: 22894322\n", + "\tDownload size: 5697295\n", + "\tSize: 28591617 bytes\n", + "\tRows: 20000\n", + "\tColumns: ['conversation_id', 'messages', 'metadata']\n", + "\u001b[2K\u001b[32m⠏\u001b[0m \u001b[32mRunning evaluation...\u001b[0m[2025-04-10 09:47:49,820][oumi][rank0][pid:16694][MainThread][INFO]][base_map_dataset.py:426] Loaded DataFrame with shape: (20000, 3). Columns:\n", + "conversation_id object\n", + "messages object\n", + "metadata object\n", + "dtype: object\n", + "\u001b[2KINFO 04-10 09:47:49 chat_utils.py:332] Detected the chat template content format\n", + "to be 'string'. You can set `--chat-template-content-format` to override this.\n", + "\u001b[2KProcessed prompts: \u001b[1;36m100\u001b[0m%|#| \u001b[1;36m100\u001b[0m/\u001b[1;36m100\u001b[0m \u001b[1m[\u001b[0m\u001b[1;92m00:03\u001b[0m<\u001b[1;92m00:00\u001b[0m, \u001b[1;36m28.\u001b[0m84it/s, est. 
speed input: \u001b[1;36m26\u001b[0m\n", + "\u001b[32m⠸\u001b[0m \u001b[32mRunning evaluation...\u001b[0m[2025-04-10 09:47:53,387][oumi][rank0][pid:16694][MainThread][INFO]][count_letters_task.py:53] Finished inference on 100 conversations!\n", + "[2025-04-10 09:47:53,387][oumi][rank0][pid:16694][MainThread][INFO]][count_letters_task.py:55] Sample conversation: conversation_id='oumi_letter_count_0' messages=[USER: Look through 'perivaginal' and count the 'n's., SYSTEM: Your final answer should be written as digits and formatted as \"\\boxed{your_answer}\". For example, if the answer is 42, make sure to output \"\\boxed{42}\"., ASSISTANT: There are 2 'n's in 'perivaginal'. \n", + "\n", + "\\boxed{2}] metadata={'letter': 'n', 'letter_count_integer': 1, 'letter_count_string': 'one', 'unformatted_prompt': 'Look through {word} and count the {letter}s.', 'word': 'perivaginal'}\n", + "\u001b[2KINFO 04-10 09:47:53 multiproc_worker_utils.py:141] Terminating local vLLM worker\n", + "processes\n", + "\u001b[32m⠧\u001b[0m \u001b[32mRunning evaluation...\u001b[0m\u001b[1;36m(VllmWorkerProcess pid=16717)\u001b[0;0m INFO 04-10 09:47:53 multiproc_worker_utils.py:253] Worker exiting\n", + "\u001b[2K\u001b[32m⠧\u001b[0m \u001b[32mRunning evaluation...\u001b[0m\n", + "\u001b[1A\u001b[2K\u001b[1;35m Evaluation Results \u001b[0m\n", + "┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mBenchmark \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mMetric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mScore \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mStd Error\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36mcount_letters\u001b[0m\u001b[36m \u001b[0m│\u001b[33m \u001b[0m\u001b[33mAccuracy \u001b[0m\u001b[33m \u001b[0m│\u001b[32m \u001b[0m\u001b[32m28.00%\u001b[0m\u001b[32m \u001b[0m│\u001b[2m \u001b[0m\u001b[2m- 
\u001b[0m\u001b[2m \u001b[0m│\n", + "├───────────────┼─────────────────────────────┼────────┼───────────┤\n", + "│\u001b[36m \u001b[0m\u001b[36mcount_letters\u001b[0m\u001b[36m \u001b[0m│\u001b[33m \u001b[0m\u001b[33mProperly Extracted Accuracy\u001b[0m\u001b[33m \u001b[0m│\u001b[32m \u001b[0m\u001b[32m31.46%\u001b[0m\u001b[32m \u001b[0m│\u001b[2m \u001b[0m\u001b[2m- \u001b[0m\u001b[2m \u001b[0m│\n", + "├───────────────┼─────────────────────────────┼────────┼───────────┤\n", + "│\u001b[36m \u001b[0m\u001b[36mcount_letters\u001b[0m\u001b[36m \u001b[0m│\u001b[33m \u001b[0m\u001b[33mNum Samples \u001b[0m\u001b[33m \u001b[0m│\u001b[32m \u001b[0m\u001b[32m100.00\u001b[0m\u001b[32m \u001b[0m│\u001b[2m \u001b[0m\u001b[2m- \u001b[0m\u001b[2m \u001b[0m│\n", + "├───────────────┼─────────────────────────────┼────────┼───────────┤\n", + "│\u001b[36m \u001b[0m\u001b[36mcount_letters\u001b[0m\u001b[36m \u001b[0m│\u001b[33m \u001b[0m\u001b[33mNum Correct Answers \u001b[0m\u001b[33m \u001b[0m│\u001b[32m \u001b[0m\u001b[32m28.00 \u001b[0m\u001b[32m \u001b[0m│\u001b[2m \u001b[0m\u001b[2m- \u001b[0m\u001b[2m \u001b[0m│\n", + "├───────────────┼─────────────────────────────┼────────┼───────────┤\n", + "│\u001b[36m \u001b[0m\u001b[36mcount_letters\u001b[0m\u001b[36m \u001b[0m│\u001b[33m \u001b[0m\u001b[33mNum Incorrect Answers \u001b[0m\u001b[33m \u001b[0m│\u001b[32m \u001b[0m\u001b[32m61.00 \u001b[0m\u001b[32m \u001b[0m│\u001b[2m \u001b[0m\u001b[2m- \u001b[0m\u001b[2m \u001b[0m│\n", + "├───────────────┼─────────────────────────────┼────────┼───────────┤\n", + "│\u001b[36m \u001b[0m\u001b[36mcount_letters\u001b[0m\u001b[36m \u001b[0m│\u001b[33m \u001b[0m\u001b[33mNum Invalid Answers \u001b[0m\u001b[33m \u001b[0m│\u001b[32m \u001b[0m\u001b[32m11.00 \u001b[0m\u001b[32m \u001b[0m│\u001b[2m \u001b[0m\u001b[2m- \u001b[0m\u001b[2m \u001b[0m│\n", + "└───────────────┴─────────────────────────────┴────────┴───────────┘\n", + "[rank0]:[W410 09:47:58.432638348 ProcessGroupNCCL.cpp:1250] Warning: 
WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())\n" + ] + } + ], + "source": [ + "!oumi evaluate -c letter_counting_tutorial/llama_3b_eval.yaml \\\n", + " --model.model_name \"letter_counting_tutorial/llama_3b_grpo\" \\\n", + " --tasks.0.num_samples $NUM_SAMPLES \\\n", + " --output_dir \"letter_counting_tutorial/evaluation/llama_3_grpo\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "oumi", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/oumi/datasets/grpo/letter_count.py b/src/oumi/datasets/grpo/letter_count.py index 752a7ad3c..cef1b9c5c 100644 --- a/src/oumi/datasets/grpo/letter_count.py +++ b/src/oumi/datasets/grpo/letter_count.py @@ -20,9 +20,9 @@ from oumi.core.types.conversation import Conversation _SYSTEM_PROMPT = ( - "Your final answer should be written as digits and formatted as " + "Your final answer should be an integer written as digits and formatted as " r'"\boxed{your_answer}". For example, if the answer is 42, ' - r'make sure to output "\boxed{42}".' + r'you should output "\boxed{42}".' 
) @@ -54,10 +54,11 @@ class LetterCountGrpoDataset(BaseExperimentalGrpoDataset): @override def transform(self, sample: pd.Series) -> dict: """Validate and transform the sample into Python `dict`.""" - # TODO: OPE-1122: Add system prompt to training. - # OPE-1158 seems to affect this, as the type of the input isn't consistent. + # Add system prompt before user prompt. + system_message = {"content": _SYSTEM_PROMPT, "role": "system"} + messages = [system_message, sample["messages"][0]] return { - "prompt": sample["messages"], + "prompt": messages, "letter_count": sample["metadata"]["letter_count_integer"], } @@ -74,8 +75,8 @@ def transform_conversation(self, sample: pd.Series) -> Conversation: """ # Example is already in conversation format and only needs light processing. sample_dict = sample.to_dict() - # Convert messages from np.ndarray to list. - sample_dict["messages"] = sample_dict["messages"].tolist() - # Add system prompt. - sample_dict["messages"].append({"content": _SYSTEM_PROMPT, "role": "system"}) + # Add system prompt before user prompt. + system_message = {"content": _SYSTEM_PROMPT, "role": "system"} + messages = [system_message, sample["messages"][0]] + sample_dict["messages"] = messages return Conversation.from_dict(sample_dict) diff --git a/src/oumi/datasets/grpo/rewards/count_letters_rewards.py b/src/oumi/datasets/grpo/rewards/count_letters_rewards.py index 1284935c7..3b50b6bd2 100644 --- a/src/oumi/datasets/grpo/rewards/count_letters_rewards.py +++ b/src/oumi/datasets/grpo/rewards/count_letters_rewards.py @@ -18,13 +18,12 @@ from oumi.core.registry import RegistryType, register -def _find_last_number(s: str) -> Optional[int]: - """Finds the last number (aka adjacent digits) in a string, or None if not found.""" - # Find all groups of consecutive digits in the string. 
- regex_result = re.findall(r"\d+", s) - if not regex_result: +def _extract_prediction(response: str) -> Optional[int]: + r"""Returns the numeric answer extracted from `\boxed{...}`, or None otherwise.""" + regex_result = re.findall(r"\\boxed\{([-+]?\d+)\}", response) + if not regex_result or len(regex_result) != 1: return None - number_str = regex_result[-1] + number_str = regex_result[0] # Except clause shouldn't trigger because the regex should only find ints. try: return int(number_str) @@ -32,12 +31,13 @@ def _find_last_number(s: str) -> Optional[int]: return None -def compute_letter_count_reward(completion: str, target_count: int) -> int: +def compute_letter_count_reward(completion: str, target_count: int) -> float: """Computes the rewards for counting the letters in a string. The last group of consecutive digits in the completion is assumed to be the letter count. We're also assuming it's counting the correct letter. The reward is the - negative of the absolute difference between the count and the target count. + negative of the absolute difference between the count and the target count, plus 0.1 + if the answer was properly formatted. For example, for the string "There are 2 'r's in strawberry", and the target count is 3, the reward is -1. @@ -51,10 +51,11 @@ def compute_letter_count_reward(completion: str, target_count: int) -> int: the count and the target count. The count is assumed to be the last group of consecutive digits in the completion string. 
""" - count = _find_last_number(completion) + count = _extract_prediction(completion) + formatting_reward = 0.1 if count is not None else 0 if count is None: count = 0 - return -abs(count - target_count) + return -abs(count - target_count) + formatting_reward @register("count_letters", RegistryType.REWARD_FUNCTION) @@ -62,7 +63,7 @@ def _count_letters( completions: list[list[dict[str, Any]]], letter_count: list[int], **kwargs: dict[str, Any], -) -> list[int]: +) -> list[float]: """Custom reward function for counting letters in a string. For more details on custom reward functions used in trl's GRPOTrainer, see: diff --git a/src/oumi/evaluation/registry/count_letters_task.py b/src/oumi/evaluation/registry/count_letters_task.py index 9020237c7..01e48d0d3 100644 --- a/src/oumi/evaluation/registry/count_letters_task.py +++ b/src/oumi/evaluation/registry/count_letters_task.py @@ -24,7 +24,7 @@ def _extract_prediction(response: str) -> Optional[int]: r"""Returns the numeric answer extracted from `\boxed{...}`, or None otherwise.""" - regex_result = re.findall(r"\\boxed\{(\d+)\}", response) + regex_result = re.findall(r"\\boxed\{([-+]?\d+)\}", response) if not regex_result or len(regex_result) != 1: return None number_str = regex_result[0] diff --git a/tests/unit/datasets/grpo/rewards/test_count_letters_rewards.py b/tests/unit/datasets/grpo/rewards/test_count_letters_rewards.py index 13b53bfd9..c25789de3 100644 --- a/tests/unit/datasets/grpo/rewards/test_count_letters_rewards.py +++ b/tests/unit/datasets/grpo/rewards/test_count_letters_rewards.py @@ -6,11 +6,21 @@ @pytest.mark.parametrize( "s,target_count,reward", [ - ("foo bar 1", 1, 0), - ("foo bar1", 1, 0), - ("foo bar one", 1, -1), - ("11 1", 1, 0), - ("The number of 'r's in strawberry is 10.", 3, -7), + # No valid answer + ("foo bar 1", 1, -1), + # Valid correct answer + (r"\boxed{1}", 1, 0.1), + # Valid correct answer + (r"\boxed{+1}", 1, 0.1), + # Valid incorrect answer + (r"\boxed{4}", 1, -2.9), + # Valid 
incorrect answer + (r"\boxed{-1}", 1, -1.9), + # Invalid answer + (r"The answer is \boxed{one}", 0, 0), + # Conflicting answers + (r"\boxed{1} \boxed{2}", 1, -1), + (r"The number of 'r's in strawberry is \boxed{10}.", 3, -6.9), ], ) def test_compute_soft_target_token_length_reward(s, target_count, reward): From 55adedc0e72aa407cc4f387b07a9afdd39b1d386 Mon Sep 17 00:00:00 2001 From: Kostas Date: Fri, 11 Apr 2025 11:24:25 -0700 Subject: [PATCH 09/15] [Remote Inference] Update Default Params (#1630) --- .../core/configs/params/generation_params.py | 4 +- src/oumi/core/configs/params/remote_params.py | 2 +- .../inference/anthropic_inference_engine.py | 5 ++ src/oumi/inference/gcp_inference_engine.py | 5 ++ src/oumi/inference/gemini_inference_engine.py | 7 ++- src/oumi/inference/openai_inference_engine.py | 13 ++++- src/oumi/inference/remote_inference_engine.py | 6 ++- .../test_anthropic_inference_engine.py | 8 +++ .../inference/test_gcp_inference_engine.py | 8 +++ .../inference/test_gemini_inference_engine.py | 8 +++ .../unit/inference/test_generation_params.py | 2 +- .../inference/test_openai_inference_engine.py | 51 ++++++++++++++++++- 12 files changed, 110 insertions(+), 9 deletions(-) diff --git a/src/oumi/core/configs/params/generation_params.py b/src/oumi/core/configs/params/generation_params.py index a8ab90a34..9f2297ba9 100644 --- a/src/oumi/core/configs/params/generation_params.py +++ b/src/oumi/core/configs/params/generation_params.py @@ -21,11 +21,11 @@ @dataclass class GenerationParams(BaseParams): - max_new_tokens: int = 256 + max_new_tokens: int = 1024 """The maximum number of new tokens to generate. This limits the length of the generated text to prevent excessively long outputs. - Default is 256 tokens. + Default is 1024 tokens. 
""" batch_size: Optional[int] = 1 diff --git a/src/oumi/core/configs/params/remote_params.py b/src/oumi/core/configs/params/remote_params.py index f7f446b69..e1b77f008 100644 --- a/src/oumi/core/configs/params/remote_params.py +++ b/src/oumi/core/configs/params/remote_params.py @@ -36,7 +36,7 @@ class RemoteParams(BaseParams): max_retries: int = 3 """Maximum number of retries to attempt when calling an API.""" - connection_timeout: float = 20.0 + connection_timeout: float = 300.0 """Timeout in seconds for a request to an API.""" num_workers: int = 1 diff --git a/src/oumi/inference/anthropic_inference_engine.py b/src/oumi/inference/anthropic_inference_engine.py index bf14e6dff..8cd5502a0 100644 --- a/src/oumi/inference/anthropic_inference_engine.py +++ b/src/oumi/inference/anthropic_inference_engine.py @@ -151,3 +151,8 @@ def get_supported_params(self) -> set[str]: "temperature", "top_p", } + + @override + def _default_remote_params(self) -> RemoteParams: + """Returns the default remote parameters.""" + return RemoteParams(num_workers=5, politeness_policy=60.0) diff --git a/src/oumi/inference/gcp_inference_engine.py b/src/oumi/inference/gcp_inference_engine.py index 4b561b5c1..9d7ce23d7 100644 --- a/src/oumi/inference/gcp_inference_engine.py +++ b/src/oumi/inference/gcp_inference_engine.py @@ -71,6 +71,11 @@ def _get_request_headers( } return headers + @override + def _default_remote_params(self) -> RemoteParams: + """Returns the default remote parameters.""" + return RemoteParams(num_workers=10, politeness_policy=60.0) + @override def _convert_conversation_to_api_input( self, diff --git a/src/oumi/inference/gemini_inference_engine.py b/src/oumi/inference/gemini_inference_engine.py index 748c9be3e..eb0cae077 100644 --- a/src/oumi/inference/gemini_inference_engine.py +++ b/src/oumi/inference/gemini_inference_engine.py @@ -16,7 +16,7 @@ from typing_extensions import override -from oumi.core.configs import GenerationParams, ModelParams +from oumi.core.configs import 
GenerationParams, ModelParams, RemoteParams from oumi.core.types.conversation import Conversation from oumi.inference.gcp_inference_engine import ( _convert_guided_decoding_config_to_api_input, @@ -100,3 +100,8 @@ def infer_batch( str: The batch ID. """ raise NotImplementedError("Batch inference is not supported for Gemini API.") + + @override + def _default_remote_params(self) -> RemoteParams: + """Returns the default remote parameters.""" + return RemoteParams(num_workers=2, politeness_policy=60.0) diff --git a/src/oumi/inference/openai_inference_engine.py b/src/oumi/inference/openai_inference_engine.py index 82d1fbd3e..4539d1daa 100644 --- a/src/oumi/inference/openai_inference_engine.py +++ b/src/oumi/inference/openai_inference_engine.py @@ -17,7 +17,7 @@ from typing_extensions import override -from oumi.core.configs import GenerationParams, ModelParams +from oumi.core.configs import GenerationParams, ModelParams, RemoteParams from oumi.core.types.conversation import Conversation from oumi.inference.remote_inference_engine import RemoteInferenceEngine @@ -56,13 +56,22 @@ def _convert_conversation_to_api_input( Returns: Dict[str, Any]: A dictionary representing the OpenAI input. """ - # o1-preview does NOT support logit_bias. if model_params.model_name == "o1-preview": generation_params = copy.deepcopy(generation_params) + + # o1-preview does NOT support logit_bias. generation_params.logit_bias = {} + # o1-preview only supports temperature = 1. 
+ generation_params.temperature = 1.0 + return super()._convert_conversation_to_api_input( conversation=conversation, generation_params=generation_params, model_params=model_params, ) + + @override + def _default_remote_params(self) -> RemoteParams: + """Returns the default remote parameters.""" + return RemoteParams(num_workers=50, politeness_policy=60.0) diff --git a/src/oumi/inference/remote_inference_engine.py b/src/oumi/inference/remote_inference_engine.py index cdf515477..9b037e966 100644 --- a/src/oumi/inference/remote_inference_engine.py +++ b/src/oumi/inference/remote_inference_engine.py @@ -223,7 +223,7 @@ def __init__( if remote_params: remote_params = copy.deepcopy(remote_params) else: - remote_params = RemoteParams() + remote_params = self._default_remote_params() if not remote_params.api_url: remote_params.api_url = self.base_url @@ -232,6 +232,10 @@ def __init__( self._remote_params = remote_params self._remote_params.finalize_and_validate() + def _default_remote_params(self) -> RemoteParams: + """Returns the default remote parameters.""" + return RemoteParams() + @staticmethod def _get_list_of_message_json_dicts( messages: list[Message], diff --git a/tests/unit/inference/test_anthropic_inference_engine.py b/tests/unit/inference/test_anthropic_inference_engine.py index e7f7f0349..2755b6837 100644 --- a/tests/unit/inference/test_anthropic_inference_engine.py +++ b/tests/unit/inference/test_anthropic_inference_engine.py @@ -74,3 +74,11 @@ def test_get_request_headers(anthropic_engine): assert result["Content-Type"] == "application/json" assert result["anthropic-version"] == AnthropicInferenceEngine.anthropic_version assert result["X-API-Key"] == "test_api_key" + + +def test_remote_params_defaults(): + anthropic_engine = AnthropicInferenceEngine( + model_params=ModelParams(model_name="some_model"), + ) + assert anthropic_engine._remote_params.num_workers == 5 + assert anthropic_engine._remote_params.politeness_policy == 60.0 diff --git 
a/tests/unit/inference/test_gcp_inference_engine.py b/tests/unit/inference/test_gcp_inference_engine.py index cb2d04bc4..8bddf10c7 100644 --- a/tests/unit/inference/test_gcp_inference_engine.py +++ b/tests/unit/inference/test_gcp_inference_engine.py @@ -181,3 +181,11 @@ def test_infer_from_file(gcp_engine, conversation, inference_config, tmp_path): assert len(results) == 1 assert results[0] == conversation + + +def test_remote_params_defaults(): + gcp_engine = GoogleVertexInferenceEngine( + model_params=ModelParams(model_name="some_model"), + ) + assert gcp_engine._remote_params.num_workers == 10 + assert gcp_engine._remote_params.politeness_policy == 60.0 diff --git a/tests/unit/inference/test_gemini_inference_engine.py b/tests/unit/inference/test_gemini_inference_engine.py index c4f4e28b3..14fe23e42 100644 --- a/tests/unit/inference/test_gemini_inference_engine.py +++ b/tests/unit/inference/test_gemini_inference_engine.py @@ -194,3 +194,11 @@ def test_gemini_batch_prediction_disabled(gemini_engine, inference_config): with pytest.raises(NotImplementedError): gemini_engine.infer_batch([conversation], inference_config) + + +def test_remote_params_defaults(): + gemini_engine = GoogleGeminiInferenceEngine( + model_params=ModelParams(model_name="some_model"), + ) + assert gemini_engine._remote_params.num_workers == 2 + assert gemini_engine._remote_params.politeness_policy == 60.0 diff --git a/tests/unit/inference/test_generation_params.py b/tests/unit/inference/test_generation_params.py index 84f676187..68a4b747a 100644 --- a/tests/unit/inference/test_generation_params.py +++ b/tests/unit/inference/test_generation_params.py @@ -231,7 +231,7 @@ def test_generation_params_defaults_used_in_inference( mock_infer.assert_called_once() called_params = mock_infer.call_args[0][1].generation - assert called_params.max_new_tokens == 256 + assert called_params.max_new_tokens == 1024 assert called_params.temperature == 0.0 assert called_params.top_p == 1.0 assert 
called_params.frequency_penalty == 0.0 diff --git a/tests/unit/inference/test_openai_inference_engine.py b/tests/unit/inference/test_openai_inference_engine.py index f9780d12a..6d55fb30f 100644 --- a/tests/unit/inference/test_openai_inference_engine.py +++ b/tests/unit/inference/test_openai_inference_engine.py @@ -1,6 +1,7 @@ import pytest -from oumi.core.configs import ModelParams, RemoteParams +from oumi.core.configs import GenerationParams, ModelParams, RemoteParams +from oumi.core.types.conversation import Conversation, Message, Role from oumi.inference.openai_inference_engine import OpenAIInferenceEngine @@ -34,3 +35,51 @@ def test_openai_init_default_params(): assert engine._model_params.model_name == "gpt-4" assert engine._remote_params.api_url == "https://api.openai.com/v1/chat/completions" assert engine._remote_params.api_key_env_varname == "OPENAI_API_KEY" + + +@pytest.mark.parametrize( + ( + "model_name," + "logit_bias," + "temperature," + "expected_logit_bias," + "expected_temperature," + ), + [ + ("some_model", {"token": 0.0}, 0.0, {"token": 0.0}, 0.0), + ("o1-preview", {"token": 0.0}, 0.0, {}, 1.0), + ], + ids=[ + "test_default_params", + "test_default_params_o1_preview", + ], +) +def test_default_params( + model_name, logit_bias, temperature, expected_logit_bias, expected_temperature +): + openai_engine = OpenAIInferenceEngine( + model_params=ModelParams(model_name=model_name), + generation_params=GenerationParams( + temperature=temperature, + logit_bias=logit_bias, + ), + ) + assert openai_engine._remote_params.num_workers == 50 + assert openai_engine._remote_params.politeness_policy == 60.0 + + conversation = Conversation( + messages=[ + Message(content="Hello", role=Role.USER), + ] + ) + + api_input = openai_engine._convert_conversation_to_api_input( + conversation, openai_engine._generation_params, openai_engine._model_params + ) + + assert api_input["model"] == model_name + assert api_input["temperature"] == expected_temperature + if 
expected_logit_bias: + assert api_input["logit_bias"] == expected_logit_bias + else: + assert "logit_bias" not in api_input From d009c2d6978e8bb345e12198be02466a86a1332e Mon Sep 17 00:00:00 2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Fri, 11 Apr 2025 13:25:10 -0700 Subject: [PATCH 10/15] Update trl to 0.16 (#1631) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ee3b3bf00..298bcc6e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ dependencies = [ # enabled. See OPE-875 and https://github.com/huggingface/transformers/issues/36040. "transformers>=4.51.0,<4.52", # >=0.14.0 is needed for GRPOTrainer. - "trl>=0.15.0,<0.16", + "trl>=0.16.0,<0.17", "typer", # Used by CLI "typing_extensions", # Backports of typing updates to python 3.9 "wandb>=0.19.3,<0.20", # Logging to Weights and Biases. From 159183e2a9ba5bf16e38aa6c69f3c8bebbcab3cb Mon Sep 17 00:00:00 2001 From: nikg4 Date: Mon, 14 Apr 2025 18:32:34 -0700 Subject: [PATCH 11/15] Support custom `processor args` in `ModelParams` (#1634) --- src/oumi/builders/processors.py | 23 +++++-- src/oumi/core/configs/params/model_params.py | 23 ++++++- .../core/evaluation/backends/lm_harness.py | 1 + ...language_conversation_feature_generator.py | 1 + .../inference/native_text_inference_engine.py | 1 + src/oumi/inference/sglang_inference_engine.py | 1 + src/oumi/train.py | 1 + tests/unit/builders/test_processors.py | 18 ++++- .../core/configs/params/test_model_params.py | 68 ++++++++++++++----- .../evaluation/test_backend_lm_harness.py | 1 + 10 files changed, 112 insertions(+), 26 deletions(-) diff --git a/src/oumi/builders/processors.py b/src/oumi/builders/processors.py index c68f1ec1b..661b582e6 100644 --- a/src/oumi/builders/processors.py +++ b/src/oumi/builders/processors.py @@ -13,7 +13,7 @@ # limitations under the License. 
import functools -from typing import Optional +from typing import Any, Optional import transformers @@ -27,13 +27,20 @@ def build_processor( - processor_name: str, tokenizer: BaseTokenizer, *, trust_remote_code: bool = False + processor_name: str, + tokenizer: BaseTokenizer, + *, + processor_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = False, ) -> BaseProcessor: """Builds a processor. Args: processor_name: A name of the processor (usually, equals to a model name). tokenizer: A tokenizer to use with the processor. + processor_kwargs: A dictionary of processor-specific parameters. + These parameters are passed to the processor constructor. + They can override model-specific parameters. trust_remote_code: Whether to allow loading remote code for this processor Some processors come with downloadable executable Python files, which can be a potential security risk, unless it's from a trusted source. @@ -51,19 +58,23 @@ def build_processor( # Initialize model-specific params. label_ignore_index: Optional[int] = constants.LABEL_IGNORE_INDEX ignore_features: Optional[list[str]] = None - processor_kwargs = {} + effective_processor_kwargs = {} if model_config is not None: label_ignore_index = model_config.label_ignore_index ignore_features = model_config.ignore_features - processor_kwargs.update(model_config.processor_kwargs) + effective_processor_kwargs.update(model_config.processor_kwargs) + + if processor_kwargs is not None and len(processor_kwargs) > 0: + # Override model-specific params with user-defined ones. 
+ effective_processor_kwargs.update(processor_kwargs) create_processor_fn = functools.partial( transformers.AutoProcessor.from_pretrained, processor_name, trust_remote_code=trust_remote_code, ) - if len(processor_kwargs) > 0: - worker_processor = create_processor_fn(**processor_kwargs) + if len(effective_processor_kwargs) > 0: + worker_processor = create_processor_fn(**effective_processor_kwargs) else: worker_processor = create_processor_fn() diff --git a/src/oumi/core/configs/params/model_params.py b/src/oumi/core/configs/params/model_params.py index 5522918bc..d8a67c79f 100644 --- a/src/oumi/core/configs/params/model_params.py +++ b/src/oumi/core/configs/params/model_params.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from pathlib import Path from typing import Any, Optional @@ -69,6 +69,16 @@ class ModelParams(BaseParams): covered by other fields in ModelParams. """ + processor_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional keyword arguments to pass into the processor's constructor. + + Processors are used in Oumi for vision-language models to process image and + text inputs. This field is optional and can be left empty for text-only models, + or if not needed. + + These params override model-specific default values for these kwargs, if present. + """ + model_max_length: Optional[int] = None """The maximum sequence length the model can handle. @@ -190,6 +200,17 @@ def __post_init__(self): """Populate additional params.""" self.torch_dtype = get_torch_dtype(self.torch_dtype_str) + if len(self.processor_kwargs) > 0: + conflicting_keys = {f.name for f in fields(self)}.intersection( + self.processor_kwargs.keys() + ) + if len(conflicting_keys) > 0: + raise ValueError( + "processor_kwargs attempts to override the following " + f"reserved fields: {conflicting_keys}. " + "Use properties of ModelParams instead." 
+ ) + def __finalize_and_validate__(self): """Finalizes and validates final config params.""" # If the user didn't specify a LoRA adapter, check to see if the dir/repo diff --git a/src/oumi/core/evaluation/backends/lm_harness.py b/src/oumi/core/evaluation/backends/lm_harness.py index f6d91b4b1..fcba1abcf 100644 --- a/src/oumi/core/evaluation/backends/lm_harness.py +++ b/src/oumi/core/evaluation/backends/lm_harness.py @@ -169,6 +169,7 @@ def _generate_lm_harness_model_args( model_params.model_name, tokenizer, trust_remote_code=model_params.trust_remote_code, + processor_kwargs=model_params.processor_kwargs, ) if image_token := processor.image_token: model_args_dict["image_string"] = image_token diff --git a/src/oumi/core/feature_generators/vision_language_conversation_feature_generator.py b/src/oumi/core/feature_generators/vision_language_conversation_feature_generator.py index 70dc19367..0b8e8d90f 100644 --- a/src/oumi/core/feature_generators/vision_language_conversation_feature_generator.py +++ b/src/oumi/core/feature_generators/vision_language_conversation_feature_generator.py @@ -109,6 +109,7 @@ def __init__( f"Ignoring processor_name: {processor_name}" ) elif processor_name: + # TODO OPE-1185 Add plumbing for processor_kwargs processor = build_processor( processor_name, tokenizer, trust_remote_code=trust_remote_code ) diff --git a/src/oumi/inference/native_text_inference_engine.py b/src/oumi/inference/native_text_inference_engine.py index 3dfae761a..27a2f4af1 100644 --- a/src/oumi/inference/native_text_inference_engine.py +++ b/src/oumi/inference/native_text_inference_engine.py @@ -72,6 +72,7 @@ def __init__( self._model_params.model_name, self._tokenizer, trust_remote_code=self._model_params.trust_remote_code, + processor_kwargs=self._model_params.processor_kwargs, ) internal_model_config = find_internal_model_config_using_model_name( self._model_params.model_name, diff --git a/src/oumi/inference/sglang_inference_engine.py 
b/src/oumi/inference/sglang_inference_engine.py index 8c21e364c..e31992ff6 100644 --- a/src/oumi/inference/sglang_inference_engine.py +++ b/src/oumi/inference/sglang_inference_engine.py @@ -108,6 +108,7 @@ def __init__( self._model_params.model_name, self._tokenizer, trust_remote_code=self._model_params.trust_remote_code, + processor_kwargs=self._model_params.processor_kwargs, ) internal_model_config = find_internal_model_config_using_model_name( self._model_params.model_name, diff --git a/src/oumi/train.py b/src/oumi/train.py index 4c34fffe6..1cdad3d9f 100644 --- a/src/oumi/train.py +++ b/src/oumi/train.py @@ -273,6 +273,7 @@ def train( config.model.model_name, tokenizer, trust_remote_code=config.model.trust_remote_code, + processor_kwargs=config.model.processor_kwargs, ) use_peft = config.training.use_peft and config.peft diff --git a/tests/unit/builders/test_processors.py b/tests/unit/builders/test_processors.py index e77edaf83..75fcad1b5 100644 --- a/tests/unit/builders/test_processors.py +++ b/tests/unit/builders/test_processors.py @@ -1,5 +1,5 @@ import base64 -from typing import Final +from typing import Any, Final, Optional import numpy as np import PIL.Image @@ -39,12 +39,24 @@ def test_build_processor_empty_name(trust_remote_code, mock_tokenizer): build_processor("", mock_tokenizer, trust_remote_code=trust_remote_code) -def test_build_processor_basic_gpt2_success(mock_tokenizer): +@pytest.mark.parametrize( + "processor_kwargs", + [ + None, + {}, + ], +) +def test_build_processor_basic_gpt2_success( + processor_kwargs: Optional[dict[str, Any]], mock_tokenizer +): test_chat_template: Final[str] = build_chat_template(template_name="default") model_params = ModelParams(model_name="openai-community/gpt2") processor = build_processor( - model_params.model_name, mock_tokenizer, trust_remote_code=False + model_params.model_name, + mock_tokenizer, + trust_remote_code=False, + processor_kwargs=processor_kwargs, ) assert callable(processor) diff --git 
a/tests/unit/core/configs/params/test_model_params.py b/tests/unit/core/configs/params/test_model_params.py index 5ca0f4572..fea07ff74 100644 --- a/tests/unit/core/configs/params/test_model_params.py +++ b/tests/unit/core/configs/params/test_model_params.py @@ -1,3 +1,5 @@ +import dataclasses +from pathlib import Path from unittest.mock import call, patch import pytest @@ -13,56 +15,56 @@ def test_post_init_adapter_model_present(): assert params.adapter_model == "adapter_model" -def test_post_init_adapter_model_not_present(tmp_path): +def test_post_init_adapter_model_not_present(tmp_path: Path): # This is the expected config for FFT. - params = ModelParams(model_name=tmp_path) + params = ModelParams(model_name=str(tmp_path)) params.finalize_and_validate() - assert params.model_name == tmp_path + assert Path(params.model_name) == tmp_path assert params.adapter_model is None @patch("oumi.core.configs.params.model_params.find_adapter_config_file") def test_post_init_adapter_model_not_present_exception( - mock_find_adapter_config_file, tmp_path + mock_find_adapter_config_file, tmp_path: Path ): # This is the expected config for FFT. 
mock_find_adapter_config_file.side_effect = OSError("No adapter config found.") - params = ModelParams(model_name=tmp_path) + params = ModelParams(model_name=str(tmp_path)) params.finalize_and_validate() - assert params.model_name == tmp_path + assert Path(params.model_name) == tmp_path assert params.adapter_model is None - mock_find_adapter_config_file.assert_called_with(tmp_path) + mock_find_adapter_config_file.assert_called_with(str(tmp_path)) @patch("oumi.core.configs.params.model_params.logger") -def test_post_init_config_file_present(mock_logger, tmp_path): +def test_post_init_config_file_present(mock_logger, tmp_path: Path): with open(f"{tmp_path}/config.json", "w"): pass with open(f"{tmp_path}/adapter_config.json", "w"): pass - params = ModelParams(model_name=tmp_path) + params = ModelParams(model_name=str(tmp_path)) params.finalize_and_validate() - assert params.model_name == tmp_path - assert params.adapter_model == tmp_path + assert Path(params.model_name) == tmp_path + assert Path(params.adapter_model or "") == tmp_path mock_logger.info.assert_called_with( f"Found LoRA adapter at {tmp_path}, setting `adapter_model` to `model_name`." 
) @patch("oumi.core.configs.params.model_params.logger") -def test_post_init_config_file_not_present(mock_logger, tmp_path): +def test_post_init_config_file_not_present(mock_logger, tmp_path: Path): with open(f"{tmp_path}/adapter_config.json", "w") as f: f.write('{"base_model_name_or_path": "base_model"}') - params = ModelParams(model_name=tmp_path) + params = ModelParams(model_name=str(tmp_path)) params.finalize_and_validate() assert params.model_name == "base_model" - assert params.adapter_model == tmp_path + assert Path(params.adapter_model or "") == tmp_path assert mock_logger.info.call_count == 2 mock_logger.info.assert_has_calls( @@ -77,14 +79,48 @@ def test_post_init_config_file_not_present(mock_logger, tmp_path): @patch("oumi.core.configs.params.model_params.logger") -def test_post_init_config_file_empty(mock_logger, tmp_path): +def test_post_init_config_file_empty(mock_logger, tmp_path: Path): with open(f"{tmp_path}/adapter_config.json", "w") as f: f.write("{}") - params = ModelParams(model_name=tmp_path) + params = ModelParams(model_name=str(tmp_path)) with pytest.raises( ValueError, match="`model_name` specifies an adapter model only," " but the base model could not be found!", ): params.finalize_and_validate() + + +def _get_invalid_field_name_lists() -> list[list[str]]: + all_fields: set[str] = {f.name for f in dataclasses.fields(ModelParams())} + result = [[field_name] for field_name in all_fields] + result.extend([["valid_kwarg", field_name] for field_name in all_fields][:3]) + return result + + +def _get_test_name_for_invalid_field_name_list(x): + assert isinstance(x, list) + return "--".join(x) + + +@pytest.mark.parametrize( + "field_names", + _get_invalid_field_name_lists(), + ids=_get_test_name_for_invalid_field_name_list, +) +def test_model_params_reserved_processor_kwargs(field_names: list[str], tmp_path: Path): + invalid_names = {f.name for f in dataclasses.fields(ModelParams())}.intersection( + field_names + ) + with pytest.raises( + 
ValueError, + match=( + "processor_kwargs attempts to override the following reserved fields: " + f"{invalid_names}" + ), + ): + ModelParams( + model_name=str(tmp_path), + processor_kwargs={field_name: "foo_value" for field_name in field_names}, + ) diff --git a/tests/unit/core/evaluation/test_backend_lm_harness.py b/tests/unit/core/evaluation/test_backend_lm_harness.py index 47931b564..988e9182a 100644 --- a/tests/unit/core/evaluation/test_backend_lm_harness.py +++ b/tests/unit/core/evaluation/test_backend_lm_harness.py @@ -223,6 +223,7 @@ def test_generate_lm_harness_model_args( model_params.model_name, mock_build_tokenizer.return_value, trust_remote_code=model_params.trust_remote_code, + processor_kwargs={}, ) else: mock_build_tokenizer.assert_not_called() From b92bcb2a2230a5d7165793dc3bf266452604f6cf Mon Sep 17 00:00:00 2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Tue, 15 Apr 2025 10:59:18 -0700 Subject: [PATCH 12/15] Support BerryBench evaluation (#1635) --- .../examples/berry_bench/evaluation/eval.yaml | 36 ++++++++ .../berry_bench/evaluation/gcp_job.yaml | 60 +++++++++++++ .../letter_counting/evaluation/gcp_job.yaml | 1 + src/oumi/cli/evaluate.py | 10 ++- src/oumi/datasets/grpo/__init__.py | 2 + src/oumi/datasets/grpo/berry_bench.py | 81 +++++++++++++++++ src/oumi/evaluation/registry/__init__.py | 2 + .../evaluation/registry/berry_bench_task.py | 88 +++++++++++++++++++ .../evaluation/registry/count_letters_task.py | 5 +- 9 files changed, 281 insertions(+), 4 deletions(-) create mode 100644 configs/examples/berry_bench/evaluation/eval.yaml create mode 100644 configs/examples/berry_bench/evaluation/gcp_job.yaml create mode 100644 src/oumi/datasets/grpo/berry_bench.py create mode 100644 src/oumi/evaluation/registry/berry_bench_task.py diff --git a/configs/examples/berry_bench/evaluation/eval.yaml b/configs/examples/berry_bench/evaluation/eval.yaml new file mode 100644 index 000000000..729ed74c3 --- /dev/null +++ 
b/configs/examples/berry_bench/evaluation/eval.yaml @@ -0,0 +1,36 @@ +# Config to evaluate an LLM on the Oumi BerryBench dataset. +# +# Requirements: +# - Run `pip install oumi[gpu]` +# - Log into HF: `huggingface-cli login` +# - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct +# +# Usage: +# oumi evaluate -c oumi://configs/examples/berry_bench/evaluation/eval.yaml +# +# See Also: +# - Documentation: https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html +# - Config class: oumi.core.configs.EvaluationConfig +# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/evaluation_config.py +# - Other eval configs: configs/**/evaluation/ + +model: + model_name: "meta-llama/Llama-3.2-3B-Instruct" + model_max_length: 131072 + torch_dtype_str: "bfloat16" + attn_implementation: "sdpa" + trust_remote_code: True + +generation: + max_new_tokens: 2048 + # This isn't used by vLLM, but is used for the NATIVE inference engine. + batch_size: 4 + +tasks: + - evaluation_backend: custom + task_name: berry_bench + num_samples: 1000 + +inference_engine: VLLM # Can also use NATIVE if not running on GPUs + +output_dir: "output/berry_bench/evaluation" diff --git a/configs/examples/berry_bench/evaluation/gcp_job.yaml b/configs/examples/berry_bench/evaluation/gcp_job.yaml new file mode 100644 index 000000000..9ffc42411 --- /dev/null +++ b/configs/examples/berry_bench/evaluation/gcp_job.yaml @@ -0,0 +1,60 @@ +# Job config to evaluate an LLM on the Oumi BerryBench dataset. 
+# +# Requirements: +# - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup +# - Log into HF: `huggingface-cli login` +# - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct +# +# Usage: +# oumi launch up -c oumi://configs/examples/berry_bench/evaluation/gcp_job.yaml --cluster berry-bench-eval +# +# See Also: +# - Documentation: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html +# - Config class: oumi.core.configs.JobConfig +# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/job_config.py +# - Other job configs: configs/**/*job.yaml + +name: berry-bench-eval + +resources: + cloud: gcp + accelerators: "A100" + use_spot: false + +working_dir: . + +file_mounts: + ~/.netrc: ~/.netrc # WandB credentials + ~/.cache/huggingface/token: ~/.cache/huggingface/token # HF credentials + +envs: + # NOTE: For SFT, update this to point to your model checkpoint. + # NOTE: For LoRA, instead update this to point to your LoRA adapter. + # The base model will be inferred automatically. + MODEL_CHECKPOINT_DIR: meta-llama/Llama-3.2-3B-Instruct + WANDB_PROJECT: oumi-eval + OUMI_RUN_NAME: berry-bench.eval + +setup: | + set -e + pip install uv && uv pip install oumi[gpu] + pip install "vllm>=0.7.3,<0.8.0" + +run: | + set -e # Exit if any command failed. + source ./configs/examples/misc/sky_init.sh + + if test ${OUMI_NUM_NODES} -ne 1; then + echo "LM Harness supports max 1 node. Actual: ${OUMI_NUM_NODES} nodes." + exit 1 + fi + + echo "Starting evaluation for ${MODEL_CHECKPOINT_DIR} ..." + set -x + + oumi evaluate \ + -c oumi://configs/examples/berry_bench/evaluation/eval.yaml \ + --run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" \ + --model.model_name "${MODEL_CHECKPOINT_DIR}" + + echo "Node ${SKYPILOT_NODE_RANK} is all done!" 
diff --git a/configs/examples/letter_counting/evaluation/gcp_job.yaml b/configs/examples/letter_counting/evaluation/gcp_job.yaml index 6707e8689..44d076cfb 100644 --- a/configs/examples/letter_counting/evaluation/gcp_job.yaml +++ b/configs/examples/letter_counting/evaluation/gcp_job.yaml @@ -38,6 +38,7 @@ envs: setup: | set -e pip install uv && uv pip install oumi[gpu] + pip install "vllm>=0.7.3,<0.8.0" run: | set -e # Exit if any command failed. diff --git a/src/oumi/cli/evaluate.py b/src/oumi/cli/evaluate.py index e0f5c8fbf..fb2fdfc06 100644 --- a/src/oumi/cli/evaluate.py +++ b/src/oumi/cli/evaluate.py @@ -114,10 +114,18 @@ def evaluate( # Clean up metric name clean_metric = base_name.replace("_", " ").title() + if isinstance(value, float): + if value > 1: + value_str = f"{value:.2f}" + else: + value_str = f"{value:.2%}" + else: + # Includes ints + value_str = str(value) table.add_row( benchmark_name, clean_metric, - f"{value:.2%}" if value <= 1 else f"{value:.2f}", + value_str, stderr_display, ) cli_utils.CONSOLE.print(table) diff --git a/src/oumi/datasets/grpo/__init__.py b/src/oumi/datasets/grpo/__init__.py index 7fead73b3..10cf27310 100644 --- a/src/oumi/datasets/grpo/__init__.py +++ b/src/oumi/datasets/grpo/__init__.py @@ -14,10 +14,12 @@ """GRPO datasets module.""" +from oumi.datasets.grpo.berry_bench import BerryBenchGrpoDataset from oumi.datasets.grpo.letter_count import LetterCountGrpoDataset from oumi.datasets.grpo.tldr import TldrGrpoDataset __all__ = [ + "BerryBenchGrpoDataset", "LetterCountGrpoDataset", "TldrGrpoDataset", ] diff --git a/src/oumi/datasets/grpo/berry_bench.py b/src/oumi/datasets/grpo/berry_bench.py new file mode 100644 index 000000000..128c33888 --- /dev/null +++ b/src/oumi/datasets/grpo/berry_bench.py @@ -0,0 +1,81 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from typing_extensions import override + +from oumi.core.datasets.base_grpo_dataset import BaseExperimentalGrpoDataset +from oumi.core.registry import register_dataset +from oumi.core.types.conversation import Conversation + +_SYSTEM_PROMPT = ( + "Your final answer should be formatted as " + r'```json YOUR_ANSWER```. For example, if the answer is {"a": 1}, ' + r'you should output ```json {"a": 1}}```.' +) + + +@register_dataset("oumi-ai/berrybench-v0.1.1") +class BerryBenchGrpoDataset(BaseExperimentalGrpoDataset): + r"""Dataset class for the `oumi-ai/berrybench-v0.1.1` dataset. + + A sample from the dataset: + { + "messages": [ + { + "content": "Return a JSON object showing the frequency of each character in the word '黒い'. Only include characters that appear in the word.", + "role": "user", + } + ], + "metadata": { + "character_count": 2, + "difficulty": 3, + "expected_response": '{"\\u9ed2": 1, "\\u3044": 1}', + "language": "japanese", + "word": "黒い", + }, + } + """ # noqa: E501 + + default_dataset = "oumi-ai/berrybench-v0.1.1" + + @override + def transform(self, sample: pd.Series) -> dict: + """Transform the sample into Python `dict`.""" + # Add system prompt before user prompt. + system_message = {"content": _SYSTEM_PROMPT, "role": "system"} + messages = [system_message, sample["messages"][0]] + return { + "prompt": messages, + "metadata": sample["metadata"], + } + + @override + def transform_conversation(self, sample: pd.Series) -> Conversation: + """Converts the input sample to a Conversation. 
+ + Args: + sample (dict): The input example. + + Returns: + Conversation: The resulting conversation. + + """ + # Example is already in conversation format and only needs light processing. + sample_dict = sample.to_dict() + # Add system prompt before user prompt. + system_message = {"content": _SYSTEM_PROMPT, "role": "system"} + messages = [system_message, sample["messages"][0]] + sample_dict["messages"] = messages + return Conversation.from_dict(sample_dict) diff --git a/src/oumi/evaluation/registry/__init__.py b/src/oumi/evaluation/registry/__init__.py index 38da75070..dd725f76c 100644 --- a/src/oumi/evaluation/registry/__init__.py +++ b/src/oumi/evaluation/registry/__init__.py @@ -14,8 +14,10 @@ """Evaluation registry module.""" +from oumi.evaluation.registry.berry_bench_task import berry_bench from oumi.evaluation.registry.count_letters_task import count_letters __all__ = [ + "berry_bench", "count_letters", ] diff --git a/src/oumi/evaluation/registry/berry_bench_task.py b/src/oumi/evaluation/registry/berry_bench_task.py new file mode 100644 index 000000000..40ca0c9c2 --- /dev/null +++ b/src/oumi/evaluation/registry/berry_bench_task.py @@ -0,0 +1,88 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import re +from typing import Any, Optional + +from oumi.core.configs.params.evaluation_params import EvaluationTaskParams +from oumi.core.inference.base_inference_engine import BaseInferenceEngine +from oumi.core.registry import register_evaluation_function +from oumi.datasets.grpo.berry_bench import BerryBenchGrpoDataset +from oumi.utils.logging import logger + + +def _extract_json(response: str) -> Optional[dict]: + r"""Returns the json answer extracted from ```json ...```, or None otherwise.""" + logger.info(f"response: {response}") + # re.DOTALL lets '.' match newlines. Most LLMs use newlines in their JSON outputs. + regex_result = re.findall("```json(.*)```", response, re.DOTALL) + logger.info(f"result: {regex_result}") + if not regex_result or len(regex_result) != 1: + return None + json_str = regex_result[0] + try: + return json.loads(json_str) + except json.decoder.JSONDecodeError: + return None + + +@register_evaluation_function("berry_bench") +def berry_bench( + task_params: EvaluationTaskParams, + inference_engine: BaseInferenceEngine, +) -> dict[str, Any]: + """Custom evaluation function registered as `berry_bench`.""" + dataset = BerryBenchGrpoDataset(split="test") + num_samples = task_params.num_samples + if num_samples is None: + num_samples = len(dataset) + input_conversations = [dataset.conversation(i) for i in range(num_samples)] + conversations = inference_engine.infer(input_conversations) + logger.info(f"Finished inference on {len(conversations)} conversations!") + if len(conversations) > 0: + logger.info(f"Sample conversation: {conversations[0]}") + + count = 0 # The number of examples with correct answers extracted. + total = 0 # All examples. + valid_count = 0 # The number of examples with valid answers extracted. + for i, conversation in enumerate(conversations): + total += 1 + # Grab the model's response + response = conversation.last_message() + # Ignore cases where model didn't respond or it's a multimodal response. 
+ # For now, we focus on text-only responses. + if not response or not isinstance(response.content, str): + continue + # Count the example as correct if the extracted prediction is correct. + prediction = _extract_json(response.content) + if prediction is None: + continue + valid_count += 1 + expected_json_str = conversation.metadata["expected_response"] + expected_json = json.loads(expected_json_str) + if prediction == expected_json: + count += 1 + + return { + # Accuracy across all examples. + "accuracy": count / total if total > 0 else 0, + # Accuracy when only counting examples with properly extracted answers. + "properly_extracted_accuracy": count / valid_count if valid_count > 0 else 0, + "num_samples": num_samples, + # These three values sum up to num_samples. + "num_correct_answers": count, + "num_incorrect_answers": valid_count - count, + "num_invalid_answers": total - valid_count, + } diff --git a/src/oumi/evaluation/registry/count_letters_task.py b/src/oumi/evaluation/registry/count_letters_task.py index 01e48d0d3..aec13b735 100644 --- a/src/oumi/evaluation/registry/count_letters_task.py +++ b/src/oumi/evaluation/registry/count_letters_task.py @@ -44,7 +44,6 @@ def count_letters( dataset = LetterCountGrpoDataset(split="test") # TODO: OPE-1155: Add support for using Oumi dataset code to create the dataset. # dataset = build_dataset("oumi-ai/oumi-letter-count", tokenizer=None, sample_count=10) # noqa: E501 - # dataset = build_dataset("oumi-ai/berrybench-v0.1.0", tokenizer=None, sample_count=10) # noqa: E501 num_samples = task_params.num_samples if num_samples is None: num_samples = len(dataset) @@ -75,9 +74,9 @@ def count_letters( return { # Accuracy across all examples. - "accuracy": count / total, + "accuracy": count / total if total > 0 else 0, # Accuracy when only counting examples with properly extracted answers. 
- "properly_extracted_accuracy": count / valid_count, + "properly_extracted_accuracy": count / valid_count if valid_count > 0 else 0, "num_samples": num_samples, # These three values sum up to num_samples. "num_correct_answers": count, From 50b8b09ae699efacc56d13215206ccb2632e9354 Mon Sep 17 00:00:00 2001 From: Kostas Date: Wed, 16 Apr 2025 09:36:59 -0700 Subject: [PATCH 13/15] [Remote Inference] Error checking for `api_key` (#1638) --- src/oumi/inference/remote_inference_engine.py | 13 +++++++- .../inference/test_remote_inference_engine.py | 32 ++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/oumi/inference/remote_inference_engine.py b/src/oumi/inference/remote_inference_engine.py index 9b037e966..e7bd9c9d2 100644 --- a/src/oumi/inference/remote_inference_engine.py +++ b/src/oumi/inference/remote_inference_engine.py @@ -382,7 +382,10 @@ def _get_request_headers( if not remote_params: return headers - headers[_AUTHORIZATION_KEY] = f"Bearer {self._get_api_key(remote_params)}" + api_key = self._get_api_key(remote_params) + if api_key: + headers[_AUTHORIZATION_KEY] = f"Bearer {api_key}" + return headers def _set_required_fields_for_inference(self, remote_params: RemoteParams): @@ -428,6 +431,14 @@ async def _query_api( self._set_required_fields_for_inference(remote_params) if not remote_params.api_url: raise ValueError("API URL is required for remote inference.") + if not self._get_api_key(remote_params): + if remote_params.api_key_env_varname: + raise ValueError( + "An API key is required for remote inference with the " + f"`{self.__class__.__name__}` inference engine. " + "Please set the environment variable " + f"`{remote_params.api_key_env_varname}`." 
+ ) async with semaphore: api_input = self._convert_conversation_to_api_input( conversation, generation_params, model_params diff --git a/tests/unit/inference/test_remote_inference_engine.py b/tests/unit/inference/test_remote_inference_engine.py index d86cf7957..95309974a 100644 --- a/tests/unit/inference/test_remote_inference_engine.py +++ b/tests/unit/inference/test_remote_inference_engine.py @@ -528,6 +528,27 @@ def test_infer_no_remote_params_api_url(): ) +def test_infer_no_api_key(): + with pytest.raises( + ValueError, + match=( + r"An API key is required for remote inference with the " + r"`RemoteInferenceEngine` inference engine. Please set the environment " + r"variable `MY_API_KEY`." + ), + ): + engine = RemoteInferenceEngine( + model_params=_get_default_model_params(), + remote_params=RemoteParams( + api_url=_TARGET_SERVER, + api_key_env_varname="MY_API_KEY", # Indicates that API key is required. + ), + ) + engine.infer( + input=[Conversation(messages=[])], + ) + + def test_infer_online_empty(): engine = RemoteInferenceEngine( _get_default_model_params(), remote_params=RemoteParams(api_url=_TARGET_SERVER) @@ -1614,6 +1635,15 @@ def test_get_request_headers_with_api_key(): assert headers == {"Authorization": "Bearer test-key"} +def test_get_request_headers_without_api_key(): + remote_params = RemoteParams(api_url=_TARGET_SERVER) + engine = RemoteInferenceEngine( + _get_default_model_params(), remote_params=remote_params + ) + headers = engine._get_request_headers(remote_params) + assert headers == {} + + def test_get_request_headers_with_env_var(): with patch.dict(os.environ, {"OPENAI_API_KEY": "env-test-key"}): remote_params = RemoteParams( @@ -1637,7 +1667,7 @@ def test_get_request_headers_missing_env_var(): remote_params=remote_params, ) headers = engine._get_request_headers(remote_params) - assert headers == {"Authorization": "Bearer None"} + assert headers == {} @pytest.mark.asyncio From 3689714a0ee1d830dcc08dfec0affc14cb887ca9 Mon Sep 17 00:00:00 
2001 From: William Zeng <10782997+wizeng23@users.noreply.github.com> Date: Wed, 16 Apr 2025 10:33:49 -0700 Subject: [PATCH 14/15] Rename cnn_mnist_example to cnn_mnist_tutorial (#1640) --- notebooks/Oumi - Training CNN on Custom Dataset.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/Oumi - Training CNN on Custom Dataset.ipynb b/notebooks/Oumi - Training CNN on Custom Dataset.ipynb index dbf0f4c6a..aa174e9ef 100644 --- a/notebooks/Oumi - Training CNN on Custom Dataset.ipynb +++ b/notebooks/Oumi - Training CNN on Custom Dataset.ipynb @@ -83,7 +83,7 @@ "import numpy as np\n", "import torchvision\n", "\n", - "tutorial_dir = \"cnn_mnist_example\"\n", + "tutorial_dir = \"cnn_mnist_tutorial\"\n", "\n", "Path(tutorial_dir).mkdir(parents=True, exist_ok=True)\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\" # Disable warnings from HF" @@ -111,7 +111,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Saved 70000 examples to '/home/user/oumi/notebooks/cnn_mnist_example/mnist.npz'!\n" + "Saved 70000 examples to '/home/user/oumi/notebooks/cnn_mnist_tutorial/mnist.npz'!\n" ] } ], From e92daf4b0a568384ce80bc3c08d498984e269576 Mon Sep 17 00:00:00 2001 From: Kostas Date: Wed, 16 Apr 2025 10:46:48 -0700 Subject: [PATCH 15/15] [Remote Inference][GCP] Constructing `api_url` from the Project ID and Region (#1636) Co-authored-by: Matthew Persons --- src/oumi/inference/gcp_inference_engine.py | 93 +++++++++++++++++++ .../inference/test_gcp_inference_engine.py | 85 +++++++++++++++++ 2 files changed, 178 insertions(+) diff --git a/src/oumi/inference/gcp_inference_engine.py b/src/oumi/inference/gcp_inference_engine.py index 9d7ce23d7..6bfa3b066 100644 --- a/src/oumi/inference/gcp_inference_engine.py +++ b/src/oumi/inference/gcp_inference_engine.py @@ -13,6 +13,7 @@ # limitations under the License. 
import json +import os from typing import Any, Optional import pydantic @@ -27,6 +28,98 @@ class GoogleVertexInferenceEngine(RemoteInferenceEngine): """Engine for running inference against Google Vertex AI.""" + _API_URL_TEMPLATE = ( + "https://{region}-aiplatform.googleapis.com/v1beta1/projects/" + "{project_id}/locations/{region}/endpoints/openapi/chat/completions" + ) + """The API URL template for the GCP project. Used when no `api_url` is provided.""" + + _DEFAULT_PROJECT_ID_ENV_KEY: str = "PROJECT_ID" + """The default project ID environment key for the GCP project.""" + + _DEFAULT_REGION_ENV_KEY: str = "REGION" + """The default region environment key for the GCP project.""" + + _project_id: Optional[str] = None + """The project ID for the GCP project.""" + + _region: Optional[str] = None + """The region for the GCP project.""" + + def __init__( + self, + model_params: ModelParams, + *, + generation_params: Optional[GenerationParams] = None, + remote_params: Optional[RemoteParams] = None, + project_id_env_key: Optional[str] = None, + region_env_key: Optional[str] = None, + project_id: Optional[str] = None, + region: Optional[str] = None, + ): + """Initializes the inference Engine. + + Args: + model_params: The model parameters to use for inference. + generation_params: The generation parameters to use for inference. + remote_params: The remote parameters to use for inference. + project_id_env_key: The environment variable key name for the project ID. + region_env_key: The environment variable key name for the region. + project_id: The project ID to use for inference. + region: The region to use for inference. + """ + super().__init__( + model_params=model_params, + generation_params=generation_params, + remote_params=remote_params, + ) + if project_id and project_id_env_key: + raise ValueError( + "You cannot set both `project_id` and `project_id_env_key`." 
+ ) + if region and region_env_key: + raise ValueError("You cannot set both `region` and `region_env_key`.") + + self._project_id_env_key = ( + project_id_env_key or self._DEFAULT_PROJECT_ID_ENV_KEY + ) + self._region_env_key = region_env_key or self._DEFAULT_REGION_ENV_KEY + self._project_id = project_id + self._region = region + + @override + def _set_required_fields_for_inference(self, remote_params: RemoteParams) -> None: + """Set required fields for inference.""" + if ( + not remote_params.api_url + and not self._remote_params.api_url + and not self.base_url + ): + if self._project_id and self._region: + project_id = self._project_id + region = self._region + elif os.getenv(self._project_id_env_key) and os.getenv( + self._region_env_key + ): + project_id = os.getenv(self._project_id_env_key) + region = os.getenv(self._region_env_key) + else: + raise ValueError( + "This inference engine requires that either `api_url` is set in " + "`RemoteParams` or that both `project_id` and `region` are set. " + "You can set the `project_id` and `region` when " + "constructing a GoogleVertexInferenceEngine, " + f"or as environment variables: `{self._project_id_env_key}` and " + f"`{self._region_env_key}`." 
+ ) + + remote_params.api_url = self._API_URL_TEMPLATE.format( + project_id=project_id, + region=region, + ) + + super()._set_required_fields_for_inference(remote_params) + @override def _get_api_key(self, remote_params: RemoteParams) -> str: """Gets the authentication token for GCP.""" diff --git a/tests/unit/inference/test_gcp_inference_engine.py b/tests/unit/inference/test_gcp_inference_engine.py index 8bddf10c7..5857b9154 100644 --- a/tests/unit/inference/test_gcp_inference_engine.py +++ b/tests/unit/inference/test_gcp_inference_engine.py @@ -189,3 +189,88 @@ def test_remote_params_defaults(): ) assert gcp_engine._remote_params.num_workers == 10 assert gcp_engine._remote_params.politeness_policy == 60.0 + + +def test_setting_api_url_via_constructor_region_and_project_id(): + gcp_engine = GoogleVertexInferenceEngine( + model_params=ModelParams(model_name="some_model"), + project_id="test_project_id", + region="test_region", + ) + + # The method `_set_required_fields_for_inference` is called right before querying + # the remote API (with `_query_api`) to validate/update the remote params. + remote_params = RemoteParams() + gcp_engine._set_required_fields_for_inference(remote_params) + + expected_api_url = ( + "https://test_region-aiplatform.googleapis.com/v1beta1/projects/" + "test_project_id/locations/test_region/endpoints/openapi/chat/completions" + ) + assert remote_params.api_url == expected_api_url + + +@patch("os.getenv") +def test_setting_api_url_via_env_region_and_project_id(mock_getenv): + def mock_getenv_fn(key): + return {"PROJECT_ID": "test_project_id", "REGION": "test_region"}.get(key) + + mock_getenv.side_effect = mock_getenv_fn + + gcp_engine = GoogleVertexInferenceEngine( + model_params=ModelParams(model_name="some_model"), + ) + + # The method `_set_required_fields_for_inference` is called right before querying + # the remote API (with `_query_api`) to validate/update the remote params. 
+ remote_params = RemoteParams() + gcp_engine._set_required_fields_for_inference(remote_params) + + expected_api_url = ( + "https://test_region-aiplatform.googleapis.com/v1beta1/projects/" + "test_project_id/locations/test_region/endpoints/openapi/chat/completions" + ) + assert remote_params.api_url == expected_api_url + + +@patch("os.getenv") +def test_setting_api_url_via_env_region_and_project_id_custom_keys(mock_getenv): + def mock_getenv_fn(key): + return { + "CUSTOM_PROJECT_ID_KEY": "test_project_id", + "CUSTOM_REGION_KEY": "test_region", + }.get(key) + + mock_getenv.side_effect = mock_getenv_fn + + gcp_engine = GoogleVertexInferenceEngine( + model_params=ModelParams(model_name="some_model"), + project_id_env_key="CUSTOM_PROJECT_ID_KEY", + region_env_key="CUSTOM_REGION_KEY", + ) + + # The method `_set_required_fields_for_inference` is called right before querying + # the remote API (with `_query_api`) to validate/update the remote params. + remote_params = RemoteParams() + gcp_engine._set_required_fields_for_inference(remote_params) + + expected_api_url = ( + "https://test_region-aiplatform.googleapis.com/v1beta1/projects/" + "test_project_id/locations/test_region/endpoints/openapi/chat/completions" + ) + assert remote_params.api_url == expected_api_url + + +def test_not_setting_api_url_failure(): + gcp_engine = GoogleVertexInferenceEngine( + model_params=ModelParams(model_name="some_model"), + ) + + with pytest.raises(ValueError) as exception_info: + gcp_engine._set_required_fields_for_inference(RemoteParams()) + assert str(exception_info.value) == ( + "This inference engine requires that either `api_url` is set in `RemoteParams` " + "or that both `project_id` and `region` are set. You can set the `project_id` " + "and `region` when constructing a GoogleVertexInferenceEngine, or as " + "environment variables: `PROJECT_ID` and `REGION`." + )