Compatibility with DynamicExpressions.jl #594
Thanks Miles, I'll take a look as soon as I can. |
Thx. Actually I think I just figured out how to fix it. It just needed the following in `_rrule_getfield_common`:
fdata_for_output = if field_sym === :val
    Mooncake.fdata(pt.val)
elseif field_sym === :children
    map(value_primal, pt.children) do child_p, child_t
        if child_t isa Mooncake.NoTangent
            Mooncake.uninit_fdata(child_p)
        else
            Mooncake.FData(Mooncake.fdata(child_t))
        end
    end
else
    Mooncake.NoFData()
end
So it should be ready for a proper review now! |
I tried to integrate |
I am deeply confused by the error:
caused by: StackOverflowError:
Stacktrace:
[1] macro expansion
@ ~/work/Mooncake.jl/Mooncake.jl/src/tangents.jl:435 [inlined]
[2] macro expansion
@ ./none:0 [inlined]
[3] tangent_type(::Type{Node{Float64, 3}})
@ Mooncake ./none:0
--- the above 3 lines are repeated 1 more time ---
[7] macro expansion
@ ~/work/Mooncake.jl/Mooncake.jl/src/tangents.jl:388 [inlined]
[8] macro expansion
@ ./none:0 [inlined]
[9] tangent_type(::Type{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}}})
@ Mooncake ./none:0
--- the above 9 lines are repeated 26660 more times ---
[239950] macro expansion
@ ~/work/Mooncake.jl/Mooncake.jl/src/tangents.jl:435 [inlined]
[239951] macro expansion
@ ./none:0 [inlined]
Is this an internal stack overflow? The |
MWE:
using Mooncake, DynamicExpressions, DifferentiationInterface
operators = OperatorEnum(1 => (cos, sin, exp, log, abs), 2 => (+, -, *, /), 3 => (fma, max))
# Define components of expression
x1 = Expression(Node{Float64,3}(; feature=1); operators)
# Build expression object
expr = max(x1, x1, x1)
make_eval_sum(expr) = X -> sum(expr(X)) # (To avoid full recompilation)
backend = AutoMooncake(; config=nothing)
gradient(make_eval_sum(expr), backend, randn(1, 20))
which gets an unexpected stack overflow from the macro expansion. |
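One plausible reading of the trace above is that a structural `tangent_type` derivation keeps cycling through `Node`, the tuple of children, and `Nullable` without ever reaching a base case. The sketch below reproduces only that shape with toy types. Every name in it (`ToyNullable`, `ToyNode`, `ToyTangentNode`, `toy_tangent_type`) is hypothetical, it does not use Mooncake or DynamicExpressions, and a depth guard stands in for the real stack overflow.

```julia
# Toy stand-ins for a recursive tree type. These are NOT the real
# DynamicExpressions or Mooncake definitions.
struct ToyNullable{T}
    null::Bool
    x::T
end

mutable struct ToyNode{T}
    val::T
    children::NTuple{2,ToyNullable{ToyNode{T}}}
end

# Hypothetical hand-written tangent type for a node.
struct ToyTangentNode{Tv} end

# Naive structural derivation: recurse into every field type.
# The depth guard replaces the StackOverflowError seen in the thread.
function toy_tangent_type(::Type{T}, depth::Int=0) where {T}
    depth > 64 && error("tangent type derivation for $T did not terminate")
    isprimitivetype(T) && return T
    return Tuple{map(F -> toy_tangent_type(F, depth + 1), fieldtypes(T))...}
end

# ToyNode -> Tuple -> ToyNullable -> ToyNode -> ... never bottoms out.
cycled = try
    toy_tangent_type(ToyNode{Float64})
    false
catch
    true
end

# A more specific method on the wrapper breaks the cycle, in the same
# spirit as dispatching on `Nullable{<:AbstractExpressionNode}`.
function toy_tangent_type(::Type{T}, depth::Int=0) where {T<:ToyNullable{<:ToyNode}}
    N = T.parameters[1]
    Tv = toy_tangent_type(N.parameters[1])
    return Union{Nothing,ToyTangentNode{Tv}}
end

resolved = toy_tangent_type(ToyNode{Float64})
```

The specialized method only helps if its signature actually matches the wrapper type that the recursion reaches. If it does not match, the generic fallback is selected again and the cycle continues.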
Having spent the last couple of days looking into this, I haven't fully resolved the issue, but I have a lead. The method at https://github.com/MilesCranmer/Mooncake.jl/blob/7f7f7ff3106759394454a1866c719d8fdf4f19e1/ext/MooncakeDynamicExpressionsExt.jl#L41-L44 wasn't being called because its signature does not match the type. When I changed the function to
function Mooncake.tangent_type(::Type{T}) where {T<:Nullable{<:AbstractExpressionNode}}
    N = T.parameters[1]
    @assert N <: AbstractExpressionNode
    D = N.parameters[2]
    Tparam = N.parameters[1]
    Tv = Mooncake.tangent_type(Tparam)
    return Union{@NamedTuple{null::NoTangent, x::TangentNode{Tv,D}},NoTangent}
end
the stack overflow went away. The error becomes
error message
So I added
function Mooncake.zero_tangent_internal(
    x::Nullable{N}, dict::Mooncake.MaybeCache
) where {N<:AbstractExpressionNode}
    if x.null
        return NoTangent()
    else
        node_tangent = Mooncake.zero_tangent_internal(x.x, dict)
        return (; null=NoTangent(), x=node_tangent)
    end
end
Then the error becomes
error message
ERROR: TypeError: in typeassert, expected Mooncake.CoDual{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}}, Union{Mooncake.NoFData, Tuple{Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}}}}, got a value of type Mooncake.CoDual{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}}, Tuple{Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}}}
Stacktrace:
[1] call_mapreducer
@ ~/.julia/packages/DynamicExpressions/Kn3mj/src/base.jl:123 [inlined]
[2] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
@ Base.Experimental ./<missing>:0
[3] DerivedRule
@ ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
[4] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
@ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1827
[5] LazyDerivedRule
@ ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
[6] AbstractExpression
@ ~/.julia/packages/DynamicExpressions/Kn3mj/src/Expression.jl:515 [inlined]
[7] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
@ Base.Experimental ./<missing>:0
[8] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
[9] (::Mooncake.DynamicDerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1739
[10] (::Mooncake.RRuleZeroWrapper{…})(f::Mooncake.CoDual{…}, args::Mooncake.CoDual{…})
@ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:302
[11] #13
@ ~/TuringLang/Mooncake.jl/test_pr594/recreating_error_test.jl:24 [inlined]
[12] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
@ Base.Experimental ./<missing>:0
[13] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
[14] prepare_gradient_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
@ Mooncake ~/TuringLang/Mooncake.jl/src/interface.jl:485
[15] prepare_gradient_cache
@ ~/TuringLang/Mooncake.jl/src/interface.jl:482 [inlined]
[16] prepare_gradient_nokwarg(::Val{true}, ::var"#13#14", ::AutoMooncake{Nothing}, ::Matrix{Float64})
@ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/alBlj/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:114
[17] gradient(::var"#13#14", ::AutoMooncake{Nothing}, ::Matrix{Float64})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:62
[18] top-level scope
@ ~/TuringLang/Mooncake.jl/test_pr594/recreating_error_test.jl:25
Some type information was truncated. Use `show(err)` to see complete types.
(in a more human readable form)
# Expected type:
Mooncake.CoDual{
    Tuple{Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}},
    Union{
        Mooncake.NoFData,
        Tuple{
            Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
            Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
            Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}}
        }
    }
}
# Actual type received:
Mooncake.CoDual{
    Tuple{Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}},
    Tuple{
        Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
        Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
        Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}}
    }
}
which I am not able to resolve at the moment. (I suspect that Mooncake doesn't particularly like the
Interestingly, the code without my modifications would work for
using Mooncake, DynamicExpressions, DifferentiationInterface
operators = OperatorEnum(
    1 => (cos, sin, exp, log, abs), 2 => (+, -, *, /, max, min), 3 => (fma, max, min)
)
x1 = Expression(Node{Float64,3}(; feature=1); operators)
test_input = randn(1, 20)
backend = AutoMooncake(; config=nothing)
expr_binary = max(x1, x1)
f = X -> sum(expr_binary(X))
result = gradient(f, backend, test_input)
which tripped me up for a bit. I don't have a good answer as to why, but it might be a helpful piece of info. |
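Comparing the expected and actual types in the TypeError above, the per-element difference is that the expected tuple holds a bare named tuple (or `NoFData`), while the received tuple holds the same named tuple wrapped in `FData`. The sketch below reproduces just that subtype failure with toy types. `ToyFData`, `ToyFields`, `Expected`, and `Actual` are hypothetical stand-ins, `Nothing` plays the role of `NoFData`, and nothing here says which side of the real mismatch is the correct one.

```julia
# Hypothetical wrapper playing the role of `Mooncake.FData`.
struct ToyFData{T}
    data::T
end

# Stand-in for `@NamedTuple{null::NoFData, x::TangentNode{...}}`.
const ToyFields = @NamedTuple{null::Nothing, x::Float64}

# Shape of the expected fdata type: each element is a bare named tuple
# or the "no data" marker.
const Expected = Union{Nothing,Tuple{Union{Nothing,ToyFields},Union{Nothing,ToyFields}}}

# Shape of the received fdata type: each element is wrapped.
const Actual = Tuple{ToyFData{ToyFields},ToyFData{ToyFields}}

wrapped_matches = Actual <: Expected                    # wrapped elements
bare_matches = Tuple{ToyFields,ToyFields} <: Expected   # unwrapped elements

# The same check on a value, which is what a typeassert performs.
value = (ToyFData((null=nothing, x=1.0)), ToyFData((null=nothing, x=2.0)))
value_matches = value isa Expected
```

Tuple types are covariant in Julia, so the unwrapped tuple is accepted by the union while the wrapped one is not. This suggests that the wrapper, rather than the `Union` itself, is what the assertion rejects in this toy version.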
Thanks so much for looking into this. It is really appreciated given that I don't have much knowledge about Mooncake's internals!
It's very weird. It seems the stack overflow error only shows up when using a degree-3 node like
Which is strange because technically they all go through the exact same function call regardless of degree: https://github.com/SymbolicML/DynamicExpressions.jl/blob/b80be2a3158e8175752ae15972962a5d1cc8a53c/src/Evaluate.jl#L292-L303
Is there a specific branch in the tangent calculation that gets triggered for 3+ argument functions that differs from others? |
This is the question I wish I had a good answer for. Sorry to say, I haven't fully internalized Mooncake's code, so I couldn't say much more off the top of my head. But it's worth looking into once more free time shows up! |
I think I might (?) have just solved it. I applied this patch to DynamicExpressions.jl:
 @unstable function _get_return_type(tree, cX, operators, eval_options)
     # public Julia API version of `Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))`
-    return eltype([_eval_tree_array(tree, cX, operators, eval_options) for _ in 1:0])
+    return Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))
 end
Then the stack overflow goes away. Basically, this
While this does result in a recursive type inference, the forward pass handles it, so I am not sure why the reverse pass is unable to. Also note that I use the |
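For readers unfamiliar with the two forms in the patch above, here is a small standalone comparison on a toy function. `toy_eval` and the two helper names are hypothetical. The sketch assumes only that an empty comprehension takes its element type from inference, which is the trick the original line relies on. `Core.Compiler.return_type` is a compiler internal, and neither form is guaranteed to return the same answer across Julia versions.

```julia
# Toy function standing in for `_eval_tree_array` (hypothetical).
toy_eval(x::Vector{Float64}, flag::Bool) = (sum(x), flag)

# Form 1: build an empty array whose element type comes from inference.
# The comprehension body is never executed because the range is empty.
function return_type_via_comprehension(x, flag)
    return eltype([toy_eval(x, flag) for _ in 1:0])
end

# Form 2: ask the compiler directly (internal API).
function return_type_via_compiler(x, flag)
    return Core.Compiler.return_type(toy_eval, typeof((x, flag)))
end

x = [1.0, 2.0, 3.0]
t1 = return_type_via_comprehension(x, true)
t2 = return_type_via_compiler(x, true)
```

The first form is ordinary Julia code that an AD tool has to trace through, comprehension machinery included, while the second is a single call. That difference may be relevant to why only one of them triggered the overflow, though the thread does not confirm the mechanism.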
@yebai Here is an even more minimal MWE. First, we can see that
julia> test_data(Random.default_rng(), Node{Float64}(; feature=1))
#= (everything passes) =#
However, if we embed this into any other struct, things start to break:
julia> struct A{T}; x::T; end;
julia> test_data(Random.default_rng(), A(Node{Float64}(; feature=1)))
Test Summary: | Pass Total Time
=== | 42 42 0.0s
Test Summary: | Pass Total Time
ifelse | 92 92 0.0s
Test Summary: | Pass Total Time
sizeof | 19 19 0.0s
Test Summary: | Pass Total Time
isa | 63 63 0.0s
Test Summary: | Pass Total Time
tuple | 63 63 0.0s
Test Summary: | Pass Total Time
typeassert | 21 21 0.0s
Test Summary: | Pass Total Time
typeof | 19 19 0.0s
Test Summary: | Pass Total Time
getfield | 88 88 0.0s
Test Summary: | Pass Total Time
lgetfield | 88 88 0.0s
StackOverflowError()
args = (A{Node{Float64, 2}}, x1): Error During Test at /Users/mcranmer/PermaDocuments/Mooncake.jl/src/test_utils.jl:1263
Got exception outside of a @test
ArgumentError: rule for Mooncake.CoDual{typeof(Mooncake._new_), Mooncake.NoFData} with argument types Tuple{Mooncake.CoDual{Type{A{Node{Float64, 2}}}, Mooncake.NoFData}, Mooncake.CoDual{Node{Float64, 2}, MooncakeDynamicExpressionsExt.TangentNode{Float64, 2}}} does not run.
Stacktrace:
[1] test_rrule_interface(::Any, ::Any, ::Vararg{Any}; rule::Function)
@ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:532
[2] test_rule(::TaskLocalRNG, ::typeof(Mooncake._new_), ::Vararg{Any}; interface_only::Bool, is_primitive::Bool, perf_flag::Symbol, interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, debug_mode::Bool, unsafe_perturb::Bool)
@ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:711
[3] macro expansion
@ ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1264 [inlined]
[4] macro expansion
@ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1793 [inlined]
[5] macro expansion
@ ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1263 [inlined]
[6] macro expansion
@ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1793 [inlined]
[7] test_rule_and_type_interactions(rng::AbstractRNG, p::A{Node{Float64, 2}})
@ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1261
[8] #test_data#81
@ ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1370 [inlined]
[9] test_data(rng::TaskLocalRNG, p::A{Node{Float64, 2}})
@ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1367
[10] top-level scope
@ REPL[35]:1
[11] eval
@ ./boot.jl:430 [inlined]
[12] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
@ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:261
[13] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
@ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:368
[14] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
@ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:343
[15] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
@ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:500
[16] run_repl(repl::REPL.AbstractREPL, consumer::Any)
@ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:486
[17] (::Base.var"#1150#1152"{Bool, Symbol, Bool})(REPL::Module)
@ Base ./client.jl:446
[18] #invokelatest#2
@ ./essentials.jl:1055 [inlined]
[19] invokelatest
@ ./essentials.jl:1052 [inlined]
[20] run_main_repl(interactive::Bool, quiet::Bool, banner::Symbol, history_file::Bool, color_set::Bool)
@ Base ./client.jl:430
[21] repl_main
@ ./client.jl:567 [inlined]
[22] _start()
@ Base ./client.jl:541
caused by: StackOverflowError:
Stacktrace:
[1] tangent_type(::Type{Nullable{Node{Float64, 2}}})
@ Mooncake ./none:0
[2] macro expansion
@ ~/PermaDocuments/Mooncake.jl/src/tangents.jl:388 [inlined]
[3] macro expansion
@ ./none:0 [inlined]
[4] tangent_type(::Type{Tuple{Nullable{Node{Float64, 2}}, Nullable{Node{Float64, 2}}}})
@ Mooncake ./none:0
[5] macro expansion
@ ~/PermaDocuments/Mooncake.jl/src/tangents.jl:435 [inlined]
--- the above 3 lines are repeated 1 more time ---
[9] macro expansion
@ ./none:0 [inlined]
--- the above 9 lines are repeated 26660 more times ---
Test Summary: | Pass Error Total Time
_new_ | 4 1 5 3.0s
args = (A{Node{Float64, 2}}, x1) | 4 1 5 3.0s
ERROR: Some tests did not pass: 4 passed, 0 failed, 1 errored, 0 broken.
Even though |
I think there is a world age issue with this generic Lines 260 to 262 in ddef072
This generated function Lines 364 to 408 in ddef072
I might not have enough time today. Do you want to give it a try? |
I think part of the issue was #606, which might help explain why I was getting confusing stack traces this whole time. |
OK, with #606 (cherry-picked here) and a few missing methods implemented, this all seems to be working! Ready for a new review. |
looks good to me.
bump Mooncake version?
Signed-off-by: Miles Cranmer <miles.cranmer@gmail.com>
Done |
@MilesCranmer, feel free to merge if CI passes. |
Note that this also includes the commits in #606. Do you want to merge that separately or with this? |
Squash or regular merge? Seems my account can only create squashes |
Squash is preferred to keep the history simple. |
This is an attempt to define an extension for DynamicExpressions.jl recursive tree objects.
This assumes the v2 API, which is slightly different from the one in #428. In particular, v2 of DynamicExpressions introduces the following change to Node{T} to permit arbitrary nodes of arity D rather than only binary nodes:
where Nullable{T} is a simple container (null::Bool; x::T) to prevent invalid accesses.
I have been struggling to get this working based on the example in #428 and would really appreciate help getting the final part completed, as I have spent about as much time as I can dedicate to it.
Here is the current error:
cc @yebai @sunxd3