feat: support embedding_bag converter (1D input) #2395
See comments for suggestions on improved type enforcement. It seems that, in this specific case, some operators have `args[2]` (`offsets`) as a proper input, so it might be challenging to convert more generally with this NumPy restriction.
Thanks for the review. The issues above have been addressed!
Looks good to me, pending CI!
Description
Add support for the `embedding_bag` converter. Currently, only 1D input is supported.
Schema: https://github.com/pytorch/pytorch/blob/bdecdfd202df3fa25fd9998070fd19fee4b14971/aten/src/ATen/native/native_functions.yaml#L2251
PyTorch doc: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding_bag.html#torch-nn-functional-embedding-bag
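To make the 1D semantics concrete, here is a minimal NumPy reference sketch of what `embedding_bag` computes for a 1D `input` with `offsets`, following the PyTorch documentation linked above. It is not the converter itself: the function name `embedding_bag_1d_ref` is hypothetical, and the sketch covers only the `sum`, `mean` and `max` modes, without `per_sample_weights`, `padding_idx` or `include_last_offset`.

```python
import numpy as np


def embedding_bag_1d_ref(weight, indices, offsets, mode="sum"):
    """Hypothetical NumPy reference for embedding_bag with a 1D input.

    weight:  (num_embeddings, embedding_dim) array
    indices: 1D array of row indices into `weight`
    offsets: 1D array with the start position of each bag in `indices`
    """
    weight = np.asarray(weight)
    indices = np.asarray(indices)
    offsets = np.asarray(offsets)
    # Bag i covers indices[offsets[i]:offsets[i + 1]]; the last bag runs to the end.
    ends = np.append(offsets[1:], len(indices))
    bags = []
    for start, end in zip(offsets, ends):
        rows = weight[indices[start:end]]
        if rows.shape[0] == 0:
            # Per the PyTorch docs, empty bags produce zero vectors.
            bags.append(np.zeros(weight.shape[1], dtype=weight.dtype))
        elif mode == "sum":
            bags.append(rows.sum(axis=0))
        elif mode == "mean":
            bags.append(rows.mean(axis=0))
        elif mode == "max":
            bags.append(rows.max(axis=0))
        else:
            raise ValueError(f"unsupported mode: {mode}")
    return np.stack(bags)


weight = np.arange(12, dtype=np.float32).reshape(6, 2)  # rows [0, 1], [2, 3], ..., [10, 11]
indices = np.array([0, 2, 4, 5, 1])
offsets = np.array([0, 2])  # bag 0 = indices[0:2], bag 1 = indices[2:5]
print(embedding_bag_1d_ref(weight, indices, offsets, mode="sum").tolist())
# [[4.0, 6.0], [20.0, 23.0]]

# A 2D input of shape (B, N) is treated as B bags of fixed length N, which is
# the same as flattening it and using offsets 0, N, 2N, ...
indices_2d = np.array([[0, 2], [4, 5]])
B, N = indices_2d.shape
print(embedding_bag_1d_ref(weight, indices_2d.reshape(-1), np.arange(0, B * N, N)).tolist())
# [[4.0, 6.0], [18.0, 20.0]]
```

The loop over bags needs the concrete values in `offsets` to know where each bag starts and ends, which is why a conversion approach along these lines depends on `offsets` being readable at conversion time.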
Note:
`offsets` is only used when `input` is 1D. If `input` is 2D of shape (B, N), it will be treated as B bags (sequences) each of fixed length N, and this will return B values aggregated in a way depending on the `mode`. `offsets` is ignored and required to be `None` in this case. However, according to the schema, `offsets` is required for input of any dimension, and there is no documentation describing how it is handled in that case. There's a discussion of this in the PyTorch repo.

In addition, the converter requires `offsets` to be an ndarray or a torch tensor because it needs to access the data in it, but `offsets` could be an `ITensor` in some cases, whose values cannot be accessed at conversion time.

Fixes #2345
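The restriction on `offsets` can be pictured as a small type check run before conversion. The sketch below is hypothetical: `offsets_to_numpy` and `is_offsets_supported` are illustrative names, not Torch-TensorRT APIs, and any object that is neither an `np.ndarray` nor a `torch.Tensor` (such as a TensorRT `ITensor`) is treated as unreadable.

```python
import numpy as np

try:
    import torch
except ImportError:  # keep the sketch runnable without PyTorch installed
    torch = None


def offsets_to_numpy(offsets):
    """Hypothetical helper: return `offsets` as a NumPy array, or raise if its
    values are not available at conversion time (e.g. a TensorRT ITensor)."""
    if isinstance(offsets, np.ndarray):
        return offsets
    if torch is not None and isinstance(offsets, torch.Tensor):
        # A constant torch tensor can be copied to host memory and read directly.
        return offsets.detach().cpu().numpy()
    raise TypeError(
        "embedding_bag converter needs concrete offsets (np.ndarray or "
        f"torch.Tensor), got {type(offsets).__name__}"
    )


def is_offsets_supported(offsets):
    """Hypothetical validator: True only when the offsets values can be read."""
    try:
        offsets_to_numpy(offsets)
    except TypeError:
        return False
    return True


print(is_offsets_supported(np.array([0, 2])))  # True
print(is_offsets_supported(object()))  # False; object() stands in for an ITensor
```

Checking this before conversion, rather than failing partway through it, keeps the unsupported case explicit; how such a check would be registered with Torch-TensorRT is not shown here.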