🧬 Add `generation_kwargs` as a property of `GRPOConfig` to support additional generation arguments. #3617
trl/trainer/grpo_config.py
@@ -85,6 +85,10 @@ class GRPOConfig(TrainingArguments):
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
tokens.
generation_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Additional keyword arguments to pass to the model's `generate` method. If `None`, it defaults to an empty
Actually, it's passed to `GenerationConfig`, not `generate`. I would clarify that, depending on how you generate (using vLLM or not), this dictionary will populate either `SamplingParams` (for vLLM) or `GenerationConfig` (for other frameworks).
override the default recommended generation arguments
To be precise, it will also override any generation config already passed, such as `min_p`.
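The override behavior described in these comments can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: `merge_generation_kwargs` and the values in `base` are hypothetical, and it assumes the explicit config fields are collected into a dict before `generation_kwargs` is layered on top.

```python
from typing import Any, Optional


def merge_generation_kwargs(
    base_params: dict[str, Any],
    generation_kwargs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Return a copy of base_params with generation_kwargs layered on top.

    Hypothetical helper: keys in generation_kwargs win over keys that were
    set through dedicated config fields (for example min_p).
    """
    merged = dict(base_params)  # copy so the caller's dict is not mutated
    merged.update(generation_kwargs or {})
    return merged


# Hypothetical values standing in for the dedicated config fields.
base = {"temperature": 1.0, "top_p": 1.0, "min_p": 0.05}
merged = merge_generation_kwargs(base, {"min_p": 0.1, "top_k": 40})

# The merged dict would then be unpacked into the backend-specific object,
# e.g. GenerationConfig(**merged) with transformers or SamplingParams(**merged)
# with vLLM. The keys must be valid for whichever backend is in use.
```

Under this assumption, a `min_p` set in `generation_kwargs` replaces the one set through the dedicated field, which is the precedence the review comment asks the docstring to spell out.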
I think there is an issue with `top_p`. The values are not handled the same way in vLLM and in transformers. I don't think we can have a shared
And it seems not to work with the vLLM server.
OK, it should be better now. I also added training to the test, because the sampling happens during training. Just initializing was a bit light.
Nice addition, thanks!!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for those changes! Definitely dropped the ball on testing out with vLLM.
trl/scripts/vllm_serve.py
@@ -487,7 +487,8 @@ async def generate(request: GenerateRequest):
"max_tokens": request.max_tokens,
"guided_decoding": guided_decoding,
}
generation_kwargs.update(request.generation_kwargs)
if request.generation_kwargs is not None:
I don't think this can be `None`. If not provided, it should be an empty dict.
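One way to get the behavior this comment describes is to give the request field an empty-dict default, so the `None` check becomes unnecessary. The sketch below uses a plain dataclass as a reduced, hypothetical stand-in for the server's request model; the field set and `build_sampling_kwargs` are assumptions for illustration only.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class GenerateRequest:
    # Hypothetical, reduced stand-in for the server's request model.
    prompts: list[str] = field(default_factory=list)
    max_tokens: int = 16
    # default_factory gives every request its own empty dict, never None.
    generation_kwargs: dict[str, Any] = field(default_factory=dict)


def build_sampling_kwargs(request: GenerateRequest) -> dict[str, Any]:
    """Collect the base sampling arguments, then apply user overrides."""
    sampling_kwargs: dict[str, Any] = {"max_tokens": request.max_tokens}
    # No None check needed: the field is always a dict.
    sampling_kwargs.update(request.generation_kwargs)
    return sampling_kwargs
```

With this default, a request that omits `generation_kwargs` simply contributes no overrides, and the unconditional `update` call is safe.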
What does this PR do?
This PR is an attempt to address #3562 by allowing users to specify additional arguments that are passed to either vLLM's `SamplingParams` or HF's `GenerationConfig`, giving users more control over how the old policy generates sample responses.
Fixes #3562
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.