8000 Add entropy based filtering inside the GRPOTrainer. by pramodith · Pull Request #3563 · huggingface/trl · GitHub

Add entropy based filtering inside the GRPOTrainer. #3563


Open · pramodith wants to merge 17 commits into main

Conversation

@pramodith (Contributor) commented Jun 10, 2025

What does this PR do?

This PR relates to #3555, which proposes masking out the policy loss contributed by completion tokens at positions with entropy scores below the bottom-k percentile.

This idea was proposed by the Qwen team in the accompanying paper, Beyond the 80/20 Rule.

Key Proposals of the paper that guided the implementation

$$\mathcal{J}^{\mathcal{B}}_{\text{HighEnt}}(\theta) = \mathbb{E}_{\mathcal{B} \sim \mathcal{D},\,(q,a) \sim \mathcal{B},\, \{o^i\}_{i=1}^{G} \sim \pi_{\text{old}}(\cdot|q)} \left[ \frac{1}{\sum_{i=1}^{G} |o^i|} \sum_{i=1}^{G} \sum_{t=1}^{|o^i|} \mathbb{I}\left[H^i_t \geq \tau^{\mathcal{B}}_{\rho} \right] \cdot \min\left(r^i_t(\theta) \hat{A}^i_t,\ \text{clip}(r^i_t(\theta), 1 - \epsilon_{\text{low}}, 1 + \epsilon_{\text{high}}) \hat{A}^i_t \right) \right], \quad \text{s.t. } 0 < \left|\left\{ o^i \mid \text{is\_equivalent}(a, o^i) \right\}\right| < G$$

The key difference is the term

$$\mathbb{I}\left[H^i_t \geq \tau^{\mathcal{B}}_{\rho} \right]$$

From the paper

Here, $H^i_t$ denotes the entropy of token $t$ in response $i$, $\mathbb{I}[\cdot]$ is the indicator function that evaluates to 1 if the condition inside holds and 0 otherwise, $\rho \in (0, 1]$ is a predefined ratio specifying the top proportion of high-entropy tokens to be selected within a batch, and $\tau^{\mathcal{B}}_{\rho}$ is the corresponding entropy threshold within the batch $\mathcal{B}$ such that only tokens with $H^i_t \geq \tau^{\mathcal{B}}_{\rho}$, comprising the top-$\rho$ fraction of all tokens in the batch, are used to compute the gradient.

Entropy is calculated as normal via the formula

$$H_t := - \sum_{j=1}^{V} p_{t,j} \log p_{t,j}$$
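
In PyTorch terms, a minimal sketch of this computation (assuming access to the full logits; the helper name is illustrative, not the PR's actual code):

```python
import torch
import torch.nn.functional as F


def per_token_entropy(logits: torch.Tensor) -> torch.Tensor:
    """H_t = -sum_j p_{t,j} * log p_{t,j} for every position.

    logits: (batch, seq_len, vocab_size) -> returns (batch, seq_len).
    """
    log_probs = F.log_softmax(logits, dim=-1)  # log p_{t,j}
    probs = log_probs.exp()                    # p_{t,j}
    return -(probs * log_probs).sum(dim=-1)
```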

The paper applies the entropy mask to the DAPO loss function in its experiments, but I think we can leave it to the user to decide which loss (i.e. GRPO, Dr. GRPO, or DAPO) to apply it to.

The paper finds that the best threshold is to keep the top 20% of tokens by entropy.

We ultimately improve RLVR by restricting policy gradient updates to
forking tokens and uncover a finding even beyond the 80/20 rule: utilizing only 20% of
the tokens while maintaining performance comparable to full-gradient updates on the
Qwen3-8B base model and significantly surpassing full-gradient updates on the Qwen3-
32B (+11.04 on AIME’25 and +7.71 on AIME’24) and Qwen3-14B (+4.79 on AIME’25 and
+5.21 on AIME’24)
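
A rough sketch of how such a batch-level threshold could gate the per-token loss (tensor and helper names here are illustrative, not the PR's actual code):

```python
import torch


def entropy_mask(entropies: torch.Tensor, completion_mask: torch.Tensor, keep_fraction: float = 0.2) -> torch.Tensor:
    """Keep only the top `keep_fraction` highest-entropy completion tokens in the batch.

    entropies and completion_mask are (batch, seq_len); returns a float mask of the same shape.
    """
    valid = entropies[completion_mask.bool()]                       # entropies of real completion tokens only
    threshold = torch.quantile(valid.float(), 1.0 - keep_fraction)  # batch-level entropy threshold tau
    return (entropies >= threshold).float() * completion_mask


# The per-token policy loss would then be gated roughly as:
# per_token_loss = per_token_loss * entropy_mask(entropies, completion_mask)
```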

I didn't run the vLLM tests inside test_grpo_trainer.py since my machine/VM didn't have access to a GPU.

Fixes #3555

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LeonEricsson (Collaborator) commented Jun 12, 2025

You should be able to calculate the entropy directly from log probs as

entropy = -(per_token_logps.exp() * per_token_logps).sum(-1)

which means we don't have to modify _get_per_token_logps. I noted however that you didn't apply temperature scaling when calculating your entropy, which my proposed solution does. As I see it temperature scaling remains an unresolved variable: the paper evaluates entropy masking only at T = 1, so the optimal percentile thresholds for other temperatures are unknown. The same caveat applies to model choice—results are reported solely for the Qwen-3 series, and the threshold may shift substantially for other architectures. In practice, the 20 % rule should be treated as a starting heuristic; users will need to experiment with the temperature-model-percentile combination that best fits their own setup.

@pramodith (Contributor, Author) commented Jun 12, 2025

Ah, yeah, that's a good point; let me update the entropy calculation to use the log probs directly. Yep, the temperature and the 20% threshold should be hyper-parameters that are tuned by the user.

@pramodith (Contributor, Author)

Also, I just realized that the temperature parameter won't affect the ranking order of entropy since all positions are scaled by the same temperature, so the temp param shouldn't matter here.

@pramodith (Contributor, Author) commented Jun 12, 2025

You should be able to calculate the entropy directly from log probs as

entropy = -(per_token_logps.exp() * per_token_logps).sum(-1)

Actually this isn't true: per_token_logps only contains the log probs of the tokens in the sampled response. Entropy needs to be calculated over the probability distribution over the entire vocabulary at each sequence position in the response, so we do need to update _get_per_token_logps, since it uses the selective_log_softmax function to gather the probs of just the tokens in the sampled response.
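
To illustrate with toy shapes (not the trainer's actual code, just a sketch of the shape mismatch):

```python
import torch
import torch.nn.functional as F

# Toy sizes only, to show why per_token_logps alone isn't enough for entropy:
logits = torch.randn(2, 5, 32)                        # (batch, seq_len, vocab)
ids = torch.randint(0, 32, (2, 5))                    # sampled completion token ids

log_probs = F.log_softmax(logits, dim=-1)             # (2, 5, 32): full distribution per position
per_token_logps = log_probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1)  # (2, 5): one value per position
entropy = -(log_probs.exp() * log_probs).sum(-1)      # (2, 5): needs all vocab entries per position
```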

@LeonEricsson (Collaborator) commented Jun 12, 2025

You should be able to calculate the entropy directly from log probs as

entropy = -(per_token_logps.exp() * per_token_logps).sum(-1)

Actually this isn't true: per_token_logps only contains the log probs of the tokens in the sampled response. Entropy needs to be calculated over the probability distribution over the entire vocabulary at each sequence position in the response, so we do need to update _get_per_token_logps, since it uses the selective_log_softmax function to gather the probs of just the tokens in the sampled response.

Yes of course, my bad. Not really a fan of the proposed refactor of _get_per_token_logps, so I'm trying to think of an alternative solution. Things of concern: 1) logits is a big tensor that clogs VRAM; 2) performing torch.softmax(logits, dim=-1) will spike memory usage. See #3390 for discussion. But maybe we just have to bite the bullet here, and in that case we should definitely only perform the entropy calculation when necessary.
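
Rough numbers to illustrate the concern (illustrative sizes, not measured from the trainer):

```python
# Back-of-the-envelope estimate for materializing one extra (B, T, V) fp32 tensor:
batch, seq_len, vocab = 8, 1024, 151_936
bytes_fp32 = 4
extra_gib = batch * seq_len * vocab * bytes_fp32 / 2**30
print(f"{extra_gib:.1f} GiB")  # ~4.6 GiB just for the probabilities
```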

@qgallouedec (Member) commented Jun 13, 2025

Nice! Thanks!
Yes, as @LeonEricsson says, the tricky part is the memory peak of the softmax. There has to be a way to compute the entropy efficiently, similar to what we do with selective_log_softmax.

Another recommendation: 1 argument is probably enough.

- if self.filter_on_entropy:
+ if self.token_entropy_percentile_threshold < 1.0:

and add that the recommended value is 0.2 in the documentation.
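
For example, usage could then look roughly like this (a sketch only; the final argument name and default are to be settled in this PR):

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-entropy-filter",
    token_entropy_percentile_threshold=0.2,  # recommended 0.2; 1.0 would disable the filtering
)
```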

@pramodith (Contributor, Author)

Cool, let me look into making the entropy calculation less memory intensive.

@pramodith (Contributor, Author)

Updated the code to make sure that only a mini-batch of logits is materialized at any given point in time, and entropies for those mini-batches of logits are optionally calculated. The rowise_entropy function is inspired by the selective_log_softmax function and calculates entropy per row of the mini-batch. This introduces a bit of redundant compute because we compute the log_softmax twice (once in the entropy calc and once in selective_log_softmax), but I guess it's worth the tradeoff for the gains in clean code and memory.
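
Roughly the pattern is the following (a simplified sketch of the idea, not the exact rowise_entropy code):

```python
import torch
import torch.nn.functional as F


def chunked_entropy(logits: torch.Tensor, chunk_size: int = 1) -> torch.Tensor:
    """Per-position entropy computed a few rows at a time to cap peak memory.

    logits: (batch, seq_len, vocab_size) -> returns (batch, seq_len).
    """
    per_chunk = []
    for chunk in logits.split(chunk_size, dim=0):     # small groups of sequences
        log_probs = F.log_softmax(chunk, dim=-1)      # only this chunk's softmax is live at once
        per_chunk.append(-(log_probs.exp() * log_probs).sum(-1))
    return torch.cat(per_chunk, dim=0)
```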

@LeonEricsson (Collaborator) commented Jun 16, 2025

Updated the code to make sure that only a mini-batch of logits is materialized at any given point in time, and entropies for those mini-batches of logits are optionally calculated. The rowise_entropy function is inspired by the selective_log_softmax function and calculates entropy per row of the mini-batch. This introduces a bit of redundant compute because we compute the log_softmax twice (once in the entropy calc and once in selective_log_softmax), but I guess it's worth the tradeoff for the gains in clean code and memory.

Nice work. Left a few minor comments.

@pramodith (Contributor, Author)

Thanks for the review Leon! Made the suggested changes.

@pramodith (Contributor, Author)

@qgallouedec please take another look at the PR when you have the time.
