Add repeat vector #2409

Open, wants to merge 2 commits into base: master

Conversation

Jafagervik

RepeatVector Layer like in keras.

PS: I'm still new to contributing here, so if "trivial layers" like these are not wanted and are rather something each user can implement on their own, please let me know.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@Jafagervik Jafagervik marked this pull request as draft March 24, 2024 18:54
@Jafagervik Jafagervik marked this pull request as ready for review March 24, 2024 18:56
Comment on lines +891 to +893
expanded = reshape(x, (size(x)..., 1))
repeated = repeat(expanded, outer = (1, rv.n, 1))
return repeated
Member

This could just be:

return repeat(x, size(x)..., rv.n)

Member

That would do this:

julia> RepeatVector(2)(rand(3)) |> size
(9, 2)

julia> RepeatVector(2)(rand(3,5)) |> size
(9, 25, 2)

julia> RepeatVector(2)(rand(3,5,7)) |> size
(9, 25, 49, 2)

Perhaps you meant:

julia> function (rv::RepeatVector)(x::AbstractArray{T}) where {T}
          repeated = repeat(x, ntuple(_->1, ndims(x))..., rv.n)
       end

julia> RepeatVector(2)(rand(3)) |> size
(3, 2)

julia> RepeatVector(2)(rand(3,5)) |> size
(3, 5, 2)

julia> RepeatVector(2)(rand(3,5,7)) |> size
(3, 5, 7, 2)
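
The difference between the two calls comes down to how `repeat` reads its positional arguments: each one is a repetition count for the corresponding dimension, not a target size. A minimal sketch of this, using a hypothetical standalone function `repeat_last` (a name made up here, not part of Flux or of this PR) in place of the `RepeatVector` struct:

```julia
# Hypothetical helper, not part of Flux or the PR: add a new trailing
# dimension of length n, repeating x once along each existing dimension.
repeat_last(x::AbstractArray, n::Int) = repeat(x, ntuple(_ -> 1, ndims(x))..., n)

x = rand(3, 5)

# Positional arguments to `repeat` are per-dimension repetition counts,
# so passing size(x) multiplies each dimension by its own length.
size(repeat(x, size(x)..., 2))   # (9, 25, 2)

# Counts of 1 leave the existing dimensions untouched.
size(repeat_last(x, 2))          # (3, 5, 2)
```

Each slice along the new last dimension is then a copy of the input, for example `repeat_last(x, 2)[:, :, 1] == x`.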

"""
RepeatVector(n::Int)

Repeat the input `n` times along the last dimension.
Member

This should specify that an extra dimension is added first.



"""
RepeatVector(n::Int)
Member

RepeatArray is maybe a better name.


# Examples
```jldoctest
julia> rv = RepeatVector(3)
Member

In a jldoctest block, each line should also be followed by its corresponding output; otherwise the doctests will fail.

@CarloLucibello
Member

I would be OK with adding a layer like this, although I don't like the naming and I can't think of a really good alternative.
Tests should be added in the test/ folder, and the layer should be referenced in docs/.

@mcabbott
Member

What is this used for?

And what is the intended behaviour, on arrays of various dimensions? I'm not sure the code matches the words, nor the example given.

julia> RepeatVector(2)(rand(3)) |> size
(3, 2, 1)

julia> RepeatVector(2)(rand(3,5)) |> size
(3, 10, 1)

julia> RepeatVector(2)(rand(3,5,7)) |> size
(3, 10, 7, 1)

julia> rv = RepeatVector(3);

julia> rv([1, 2, 3])  # example says 3×3 Matrix
3×3×1 Array{Int64, 3}:
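
For comparison, Keras's `RepeatVector(n)` takes a 2D input of shape `(batch, features)` and returns `(batch, n, features)`. If the goal is the analogue in a features-first, batch-last layout, one possible reading is `features × batch` to `features × n × batch`. The sketch below assumes that reading; the name `repeat_keras_like` and the dimension ordering are assumptions made for illustration, not something the PR states:

```julia
# Hypothetical function, assuming the intended behaviour mirrors Keras
# with the batch dimension kept last: (features, batch) -> (features, n, batch).
function repeat_keras_like(x::AbstractMatrix, n::Int)
    features, batch = size(x)
    # Insert a singleton dimension between features and batch...
    expanded = reshape(x, features, 1, batch)
    # ...and repeat n times along that new dimension only.
    return repeat(expanded, 1, n, 1)
end

x = rand(3, 5)
y = repeat_keras_like(x, 2)
size(y)   # (3, 2, 5)
```

Under this reading every slice `y[:, i, :]` equals the original matrix, which is different from all three outputs shown above.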

@ToucheSir
Member

The lack of clarity on how this layer should behave suggests to me that we might not want it as a built-in. Now, the good news is that Flux doesn't require a custom layer type for everything. In this case, it suffices to use a plain function/closure:

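function repeatvector(x::AbstractArray, n::Int)
  # Append a trailing singleton dimension, then repeat n times along dimension 2.
  expanded = reshape(x, (size(x)..., 1))
  repeated = repeat(expanded, outer = (1, n, 1))
  return repeated
end
# Curried form: returns a closure that can be used like a layer.
repeatvector(n::Int) = x -> repeatvector(x, n)
f = repeatvector(2)
# or, using built-in types:
f = Base.Fix2(repeatvector, 2)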
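
To illustrate that the closure and the `Base.Fix2` form are interchangeable, here is a small self-contained sketch. It restates the two-argument function compactly with `n` taken from the argument, and the `pipeline` composition is only an illustrative stand-in for placing the callable after another layer:

```julia
# Compact restatement of the two-argument function, for a self-contained example.
repeatvector(x::AbstractArray, n::Int) =
    repeat(reshape(x, size(x)..., 1), outer = (1, n, 1))

f_closure = x -> repeatvector(x, 2)      # anonymous-function form
f_fix2 = Base.Fix2(repeatvector, 2)      # fixes the second argument to 2

x = rand(3)
f_closure(x) == f_fix2(x)   # true
size(f_fix2(x))             # (3, 2, 1)

# Either callable composes with other functions without a custom struct.
pipeline = f_fix2 ∘ (v -> 2 .* v)
size(pipeline(x))           # (3, 2, 1)
```

Since both forms are ordinary callables, no struct or extra method definitions are needed to use them where a function is accepted.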

@darsnack
Member

Going along Brian's point, even if we were to include this, I think it ought to be a function not a struct. Though I tend to lean towards not including it.
