Open
Description
🐛 Describe the bug
# Minimal reproduction: the same custom op ("mylib::foo") and its fake impl are
# declared twice, each time inside a fresh scoped library fragment.
# NOTE(review): the indentation was lost in the report and is reconstructed
# here; the registrations are assumed to sit inside the `with` blocks — confirm.
import torch

# First pass: define the op in a scoped fragment and register its fake impl.
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    torch.library.define(
        "mylib::foo",
        "(Tensor a, Tensor b, Tensor? c) -> (Tensor, Tensor?)",
        tags=torch.Tag.pt2_compliant_tag,
        lib=lib,
    )

    # Unlike define() above, register_fake() is not given `lib=lib` here.
    @torch.library.register_fake("mylib::foo")
    def foo_impl(a, b, c):
        # The second output is optional: it stays None unless `c` is provided.
        res2 = None
        if c is not None:
            res2 = c + a + b
        return a + b, res2

# Second pass: statement-for-statement identical to the first. Per the report,
# this second register_fake() call is the one that raises the RuntimeError
# saying a fake impl is already registered.
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    torch.library.define(
        "mylib::foo",
        "(Tensor a, Tensor b, Tensor? c) -> (Tensor, Tensor?)",
        tags=torch.Tag.pt2_compliant_tag,
        lib=lib,
    )

    # Again registered without `lib=lib`.
    @torch.library.register_fake("mylib::foo")
    def foo_impl(a, b, c):
        # The second output is optional: it stays None unless `c` is provided.
        res2 = None
        if c is not None:
            res2 = c + a + b
        return a + b, res2
Running this script gives us the following error:
RuntimeError: register_fake(...): the operator mylib::foo already has an fake impl registered at /data/users/yidi/pytorch/test.py:13.
This happens when I try to parametrize a unit test that registers a custom op. It may take some refactoring to make this work. For now, we can bypass the error by passing allow_override=True to torch.library.register_fake.
Versions
I'm on master