[Cherrypick for launch] Evaluate: return dict of results by kaisopos · Pull Request #1197 · oumi-ai/oumi

[Cherrypick for launch] Evaluate: return dict of results #1197

Merged · 3 commits · Jan 21, 2025

Changes from all commits

5 changes: 3 additions & 2 deletions src/oumi/__init__.py
@@ -94,14 +94,15 @@ def evaluate_async(config: AsyncEvaluationConfig) -> None:
return oumi.evaluate_async.evaluate_async(config)


def evaluate(config: EvaluationConfig) -> None:
def evaluate(config: EvaluationConfig) -> list[dict[str, Any]]:
"""Evaluates a model using the provided configuration.

Args:
config: The desired configuration for evaluation.

Returns:
None.
A list of evaluation results (one for each task). Each evaluation result is a
dictionary of metric names and their corresponding values.
"""
import oumi.evaluate

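With this change, the public oumi.evaluate entry point returns the per-task results instead of None. A minimal usage sketch (the YAML path is hypothetical, and it assumes EvaluationConfig exposes a from_yaml loader, as Oumi config classes typically do):

from oumi import evaluate
from oumi.core.configs import EvaluationConfig

# Hypothetical config path; any EvaluationConfig with one or more tasks works.
config = EvaluationConfig.from_yaml("configs/my_eval.yaml")

# Previously returned None; now returns one result dict per task, in task order.
task_results = evaluate(config)
print(f"{len(task_results)} task(s) evaluated")
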
15 changes: 11 additions & 4 deletions src/oumi/evaluate.py
@@ -1,3 +1,5 @@
from typing import Any

from oumi.core.configs import EvaluationConfig
from oumi.core.configs.params.evaluation_params import (
AlpacaEvalTaskParams,
@@ -9,15 +11,17 @@
from oumi.evaluation.platform_prerequisites import check_prerequisites


def evaluate(config: EvaluationConfig) -> None:
def evaluate(config: EvaluationConfig) -> list[dict[str, Any]]:
"""Evaluates a model using the provided configuration.

Args:
config: The desired configuration for evaluation.

Returns:
None.
A list of evaluation results (one for each task). Each evaluation result is a
dictionary of metric names and their corresponding values.
"""
results = []
for task in config.tasks:
check_prerequisites(
evaluation_platform=task.get_evaluation_platform(),
@@ -27,22 +31,23 @@ def evaluate(config: EvaluationConfig) -> None:
if task.get_evaluation_platform() == EvaluationPlatform.LM_HARNESS:
lm_harness_task_params = task.get_evaluation_platform_task_params()
assert isinstance(lm_harness_task_params, LMHarnessTaskParams)
evaluate_lm_harness(
result = evaluate_lm_harness(
task_params=lm_harness_task_params,
output_dir=config.output_dir,
model_params=config.model,
generation_params=config.generation,
enable_wandb=config.enable_wandb,
run_name=config.run_name,
)
results.append(result)
elif task.get_evaluation_platform() == EvaluationPlatform.ALPACA_EVAL:
alpaca_eval_task_params = task.get_evaluation_platform_task_params()
assert isinstance(alpaca_eval_task_params, AlpacaEvalTaskParams)
if not config.inference_engine:
raise ValueError(
"Inference engine must be specified for Alpaca Eval evaluation."
)
evaluate_alpaca_eval(
result = evaluate_alpaca_eval(
task_params=alpaca_eval_task_params,
output_dir=config.output_dir,
model_params=config.model,
@@ -51,5 +56,7 @@ def evaluate(config: EvaluationConfig) -> None:
inference_remote_params=config.inference_remote_params,
run_name=config.run_name,
)
results.append(result)
else:
raise ValueError("Unknown evaluation platform")
return results
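
As the loop above builds it, the returned list lines up one-to-one with config.tasks. A hedged sketch of pairing each task with its result dict, continuing the example from the previous file:

# Continuing the sketch above: pair each task with the dict its evaluator produced.
results = evaluate(config)
for task, result in zip(config.tasks, results):
    platform = task.get_evaluation_platform()
    # Both platform evaluators (see the diffs below) nest their metrics under a
    # "results" key; an empty dict means no metrics were produced for this task.
    metrics = result.get("results", {})
    print(platform, metrics)
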
10 changes: 9 additions & 1 deletion src/oumi/evaluation/alpaca_eval.py
@@ -34,7 +34,7 @@ def evaluate(
inference_engine_type: InferenceEngineType,
inference_remote_params: Optional[RemoteParams] = None,
run_name: Optional[str] = None,
) -> None:
) -> dict[str, Any]:
"""Evaluates a model using the Alpaca Eval framework.

For detailed documentation on the AlpacaEval framework, we refer you to the
@@ -48,6 +48,9 @@
inference_remote_params: The remote inference parameters to use.
output_dir: The directory where the evaluation results will be saved.
run_name: Unique identifier for the current run.

Returns:
The evaluation results (dict of metric names and their corresponding values).
"""
# Prerequisites
if not alpaca_eval:
@@ -173,3 +176,8 @@ def evaluate(
generation_params=generation_params,
inference_config=inference_config,
)

if metric_dict:
return {"results": metric_dict}

return {}
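
For an Alpaca Eval task, the corresponding entry in the returned list is either {} or {"results": metric_dict}. A hedged sketch of reading it back, continuing the earlier example (the metric name "win_rate" is illustrative; the actual keys come from the Alpaca Eval leaderboard output):

# Assuming the first configured task is the Alpaca Eval task.
alpaca_entry = results[0]
metrics = alpaca_entry.get("results", {})
win_rate = metrics.get("win_rate")  # None if the run produced no such metric
if win_rate is not None:
    print(f"win_rate: {win_rate}")
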
41 changes: 24 additions & 17 deletions src/oumi/evaluation/lm_harness.py
@@ -1,3 +1,4 @@
import copy
import os
import time
from datetime import datetime
@@ -49,7 +50,7 @@ def evaluate(
generation_params: GenerationParams,
enable_wandb: bool,
run_name: Optional[str] = None,
) -> None:
) -> dict[str, Any]:
"""Evaluates a model using the LM Evaluation Harness framework (EleutherAI).

For detailed documentation, we refer you to the following readme:
@@ -62,6 +63,9 @@
output_dir: The directory where the evaluation results will be saved.
enable_wandb: Whether to enable Weights & Biases (wandb) logging.
run_name: Unique identifier for wandb for the current training run.

Returns:
The evaluation results (dict of metric names and their corresponding values).
"""
if torch.cuda.is_available():
# CUDA device may be overwritten if `accelerate launch`,
@@ -136,26 +140,26 @@
wandb_logger.post_init(lm_eval_output)
wandb_logger.log_eval_result()

if output_dir:
# The LM Harness platform's task configuration is a dictionary which
# includes: the number of samples, the number of few-shots, task version(s),
# the prompt(s) text, model/git hashes, seeds, and the special tokens used
# by the tokenizer (such as `pad`, `eos`, `bos`, and `eot`).
platform_task_config = lm_eval_output

# The LM Harness platform's results is a dictionary that includes all
# evaluation metrics, which are oftentimes grouped (in `groups`) by a theme
# or a classification category.
platform_results = {
key: platform_task_config.pop(key)
for key in ["results", "groups"]
if key in platform_task_config
}
# The LM Harness platform's task configuration is a dictionary which
# includes: the number of samples, the number of few-shots, task version(s),
# the prompt(s) text, model/git hashes, seeds, and the special tokens used
# by the tokenizer (such as `pad`, `eos`, `bos`, and `eot`).
platform_task_config = lm_eval_output

# The LM Harness platform's results is a dictionary that includes all
# evaluation metrics, which are oftentimes grouped (in `groups`) by a theme
# or a classification category.
platform_results = {
key: platform_task_config.pop(key)
for key in ["results", "groups"]
if key in platform_task_config
}

if output_dir:
save_evaluation_output(
base_output_dir=output_dir,
platform=task_params.get_evaluation_platform(),
platform_results=platform_results,
platform_results=copy.deepcopy(platform_results),
platform_task_config=platform_task_config,
task_params=task_params,
start_time_str=start_time_str,
@@ -164,3 +168,6 @@
generation_params=generation_params,
inference_config=None,
)

return platform_results
return {}
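
One small but deliberate detail above: platform_results is now both saved and returned, so the diff passes copy.deepcopy(platform_results) to save_evaluation_output, presumably to keep the returned dictionaries independent of anything the save path does to its input. A self-contained illustration of the hazard the deep copy guards against (the save_and_redact helper is hypothetical, not Oumi code):

import copy

def save_and_redact(results: dict) -> None:
    # Hypothetical saver that drops a bulky field before writing to disk.
    results.get("results", {}).pop("samples", None)

platform_results = {"results": {"acc": 0.81, "samples": [0, 1, 1]}}

save_and_redact(copy.deepcopy(platform_results))  # deep copy: original left intact
assert "samples" in platform_results["results"]   # still available to the caller

With a shallow copy, or no copy at all, the nested "results" dict would be shared and the assertion would fail.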