feat(doctest): precision metric doctest #2306
@ydcjeff could you please review this PR?
Code looks good. @halsawadi, can you change the description of the PR to something like: … Thanks.
```python
def thresholded_output_transform(engine, output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

engine = Engine(thresholded_output_transform)
metric = Precision()
```
I'm not sure this is what we want to expose to our users. `thresholded_output_transform` should remain the argument to `Precision`, not to `Engine`.
Actually, that should be `process_function`.
The example shows how to binarize predictions for the Precision metric only. Returning binarized predictions from the validator may be a limitation...
In general, we can improve the docstring for Precision (the same goes for Accuracy; see also this comment: #2287 (comment)) and show all possible cases:
- binary case
- multiclass case
- multilabel case
```diff
- def thresholded_output_transform(output):
+ def thresholded_output_transform(engine, output):
```
Suggested change:
```diff
- def thresholded_output_transform(engine, output):
+ def process_fn(engine, output):
```
```python
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y
engine = Engine(thresholded_output_transform)
```
Suggested change:
```diff
- engine = Engine(thresholded_output_transform)
+ engine = Engine(process_fn)
```
@halsawadi Thank you for your work. Are you still working on this?
@halsawadi Thank you for your help. I'm closing this PR since the doctest for Precision has been merged.
Fixes part of #2265
Description:
Add a doctest for Precision in the metrics module.
Check list: