🐛 Describe the bug
The documentation for CELU() says that the alpha and inplace arguments are of type float and bool, respectively.
However, passing an int, complex, or bool for alpha, or an int, complex, or float for inplace, is accepted without error, contrary to the documentation, as shown below:
import torch
from torch import nn
my_tensor = torch.tensor([-1., 0., 1.])
celu = nn.CELU(alpha=1, inplace=1)
celu(input=my_tensor)
# tensor([-0.6321, 0.0000, 1.0000])
my_tensor = torch.tensor([-1., 0., 1.])
celu = nn.CELU(alpha=1.+0.j, inplace=1.+0.j)
celu(input=my_tensor)
# tensor([-0.6321, 0.0000, 1.0000])
my_tensor = torch.tensor([-1., 0., 1.])
celu = nn.CELU(alpha=True, inplace=1.)
celu(input=my_tensor)
# tensor([-0.6321, 0.0000, 1.0000])
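The following is a minimal pure-Python sketch (no torch required) of why these calls succeed. It assumes, without checking the PyTorch source for this exact version, that the inplace flag is only tested for truthiness and that alpha is used as a plain number. celu_reference is a hypothetical helper, not a PyTorch API, and it does not model complex alpha. Because bool is a subclass of int in Python (True == 1) and any nonzero number is truthy, each call above behaves like alpha=1.0, inplace=True.

```python
import math


def celu_reference(values, alpha=1.0, inplace=False):
    """Hypothetical pure-Python stand-in for CELU on a list of floats.

    CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    """
    out = [max(0.0, x) + min(0.0, alpha * (math.exp(x / alpha) - 1.0))
           for x in values]
    # Assumption: only the truthiness of `inplace` is checked (no isinstance
    # check), so 1, 1.0 and 1+0j all select the in-place path like True.
    if inplace:
        values[:] = out
        return values
    return out


# bool is a subclass of int, so alpha=True behaves exactly like alpha=1.
print(True == 1, isinstance(True, int))  # True True

data = [-1.0, 0.0, 1.0]
result = celu_reference(data, alpha=True, inplace=1.0)
print([round(v, 4) for v in result])  # [-0.6321, 0.0, 1.0]
print(result is data)  # True: the float flag still modified `data` in place
```

This mirrors the observed outputs: the non-bool inplace values still trigger in-place behavior, and the int and bool alpha values are numerically identical to 1.0.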
Versions
import torch
torch.__version__ # 2.3.1+cu121
cc @svekars @sekyondaMeta @AlannaBurke @albanD @brycebortree