Small incompatibility between C/Cuda maskedFill / maskedSelect · Issue #231 · torch/torch7

Open
fmassa opened this issue May 12, 2015 · 6 comments

@fmassa (Contributor) commented May 12, 2015

I almost posted this on torch/cutorch#70, but I think it should be addressed here.

There is a small incompatibility between the C and CUDA versions of maskedFill and maskedSelect.
In the C version, the mask needs to be a ByteTensor, whereas the CUDA version accepts both ByteTensor and CudaTensor.
This is probably motivated by the fact that the comparison operations in CUDA return CudaTensors, while their C counterparts return ByteTensors.

By itself, that's not a big deal, but on the C side one needs two buffers when combining comparison operations with maskedFill or maskedSelect: one for the output of the comparison operator (which is of type Tensor if one reuses a tensor for the output), and one ByteTensor for the masking.
An example makes this clearer:

t = torch.rand(5)
mask = torch.Tensor()
mask_byte = torch.ByteTensor() -- only needed because of maskedFill / maskedSelect

mask:lt(t,0.5)  -- puts the result in mask

-- mask_byte:lt(t,0.5) -- doesn't work, but could also be a solution if it worked
mask_byte:resize(mask:size()):copy(mask)

print(t[mask]) -- prints nil
print(t[mask_byte])

With that in mind, wouldn't it be better to adapt the C version of maskedFill and maskedSelect to also accept the current source Tensor type for the mask, as the CUDA version does? Or maybe we should allow the output of the comparison operators to also be a Tensor instead of only a ByteTensor? Or both?
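
For illustration, a rough sketch of what the first option would mean in practice; this is hypothetical behaviour under the proposal, not what torch7 currently does:

-- Hypothetical: if the C maskedFill/maskedSelect also accepted a mask of the
-- same type as the source tensor, the extra ByteTensor buffer would go away.
t = torch.rand(5)
mask = torch.Tensor()
mask:lt(t, 0.5)           -- works today: result goes into mask
-- print(t[mask])         -- would then work directly under the proposal
-- t:maskedFill(mask, 0)  -- likewise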

@dominikgrewe (Member) commented

The underlying problem seems to be that torch comparison operations return different types of tensors depending on whether a result tensor is provided or not:

b = torch.rand(5)
a = b:lt(0.5)
print(a:type())  -- torch.ByteTensor

a = torch.Tensor()
a:lt(b, 0.5)  -- must be a torch.Tensor

This seems to be a deliberate decision (https://github.com/torch/torch7/blob/master/TensorMath.lua#L771), although I don't know why.
@soumith do you know why the behaviour is the way it is?

@soumith (Member) commented May 20, 2015

This decision precedes me. If I had to guess, it is because a mask is better stored in a ByteTensor than in a larger tensor (memory efficiency).
On CUDA, I guess because we do not have a CudaByteTensor, Jeff implemented the masking based only on CudaTensor.

We do need to extend the CUDA side (soon) to have more tensor types (at the very least, even if those tensor types don't have all the math implemented).

@dominikgrewe (Member) commented

But why does torch return a ByteTensor when no result tensor is provided, yet expect a Tensor of the same type as the input when one is provided?

@timharley (Contributor) commented

@dominikgrewe It's down to the way function dispatch is done by torch. It finds the correct C function to call by inspecting the metatable of the first argument. For example, there is a C function implementing eq for each Tensor type, each one stored in the metatable of the corresponding type. If the lookup is done on a ByteTensor passed as the result, then the C function that gets called is the ByteTensor method, rather than the method for the type we care about.

torch.eq(result, tensor, 0) -- Looks up the "eq" metamethod on result
torch.eq(tensor, 0) -- Looks up the "eq" metamethod on tensor.
a:eq(b, 0.5) -- a is actually the result tensor here, not b, so to find the correct eq function for b's Tensor type, a must be the same type as b.
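
To make the dispatch rule concrete, here is a small sketch (assuming the default Tensor type is DoubleTensor, as in a stock torch7 install): the method form and the explicitly typed function form are the same call, which is why the result must share the input's type in the method form.

local b = torch.rand(5)
local a = torch.Tensor()
-- The method form dispatches on a's type; these two calls are equivalent:
a:lt(b, 0.5)
torch.DoubleTensor.lt(a, b, 0.5)
-- Both work only because a happens to be a DoubleTensor, i.e. the same type as b.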

The code that does all this is autogenerated during the build with cwrap. This chunk is used to generate the function that decides how to dispatch the torch.eq call:
https://github.com/torch/torch7/blob/master/TensorMath.lua#L70

Both kinds of underlying C functions exist: the ones that take a ByteTensor mask and the correctly typed ones. So we just need to work out which is the correct one to call for any particular Lua function call. By jiggling cwrap it should be possible to make torch.eq (and similar functions that use/create masks) dispatch based on the type of the second tensor argument if one is provided, but it's not going to be particularly easy!

The real root of the problem is that while the macro-based templating mechanism makes it easy to be generic across one Tensor type, it is hard to be generic across two types. The only function that works across types is :copy(), and it is hand-coded.

@timharley (Contributor) commented

@fmassa There is a way around this:

local t = torch.rand(4)
local mask_byte = torch.ByteTensor(4)
local mask_double = torch.Tensor(4)

-- Both of these work:
torch.Tensor.lt(mask_byte, t, 0.5)
torch.Tensor.lt(mask_double, t, 0.5)
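
The resulting byte mask can then presumably be fed straight into the masked operations; a minimal usage sketch:

print(t[mask_byte])          -- masked select via indexing
t:maskedFill(mask_byte, 0)   -- set the selected elements to 0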

@fmassa (Contributor, Author) commented May 26, 2015

@timharley Thanks for the example.
Unfortunately this solution needs a bit more work to be generic. For example, to make it work for floats, one needs to do

t = torch.rand(5):float()
mask_byte = torch.ByteTensor()

-- using torch.Tensor.lt doesn't work because
-- we are accessing the lt function of the default (double) type
torch.FloatTensor.lt(mask_byte,t,0.5)

Furthermore, this solution doesn't work for CudaTensors (both the mask and t need to be CudaTensors for it to work).

I was aware of these problems with the comparison operators when I posted this issue; that's the main reason I proposed changing the behaviour of maskedFill/maskedSelect to also accept the same Tensor type as the input, instead of only a ByteTensor.

But again, I don't know if that's the best solution for this problem.
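
One way to avoid hard-coding the type is to look the comparison function up through the input tensor itself, since method access on a tensor resolves to the implementation for its own type. This is only a sketch building on the examples above; ltMask is a hypothetical helper, and it still does not cover the CudaTensor case:

-- Hypothetical helper: dispatch lt based on the input tensor's own type,
-- so the same code works for DoubleTensor, FloatTensor, etc.
local function ltMask(result, input, value)
   -- input.lt resolves to the lt implementation for input's tensor type
   return input.lt(result, input, value)
end

local t = torch.rand(5):float()
local mask_byte = torch.ByteTensor()
ltMask(mask_byte, t, 0.5)   -- fills mask_byte with (t < 0.5)
print(t[mask_byte])         -- masked select with the byte mask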
