Compatibility with DynamicExpressions.jl by MilesCranmer · Pull Request #594 · chalk-lab/Mooncake.jl

Compatibility with DynamicExpressions.jl #594


Merged
merged 57 commits into chalk-lab:main on Jun 24, 2025

Conversation

MilesCranmer
Collaborator

This is an attempt to define an extension for DynamicExpressions.jl's recursive tree objects.

This assumes the v2 API, which is slightly different from the one in #428. In particular, v2 of DynamicExpressions introduces the following change to Node{T} to permit nodes of arbitrary arity D rather than only binary nodes:

- mutable struct Node{T} <: AbstractExpressionNode{T}
+ mutable struct Node{T,D} <: AbstractExpressionNode{T,D}
       degree::UInt8
       constant::Bool
       val::T
       feature::UInt16
       op::UInt8
-      l::Node{T}
-      r::Node{T}
+      children::NTuple{D,Nullable{Node{T,D}}}
  end

where Nullable{T} is a simple container with fields null::Bool and x::T, used to prevent invalid accesses to empty child slots.
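
For reference, a minimal sketch of what such a container might look like (the actual definition lives in DynamicExpressions.UtilsModule and may differ in its details):

struct Nullable{T}
    null::Bool  # true when this child slot is empty
    x::T        # the stored node; only meaningful when `null` is false
end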

I have been struggling to get this working based on the example in #428, and I would really appreciate help completing the final part, as I have already spent about as much time on it as I can dedicate.

Here is the current error:

julia> using Mooncake, DynamicExpressions, DifferentiationInterface, Random

julia> operators = OperatorEnum(1 => (cos, sin), 2 => (+, -, *, /));

julia> x1, x2 = (
           Expression(Node{Float64}(; feature=i); operators) for i in 1:2
       );

julia> f = x1 + cos(x2 - 0.2) + 0.5
(x1 + cos(x2 - 0.2)) + 0.5

julia> X = randn(MersenneTwister(0), 3, 100);

julia> eval_sum = let f = f
           X -> sum(f(X))
       end;

julia> backend = AutoMooncake(; config=nothing);

julia> prep = prepare_gradient(eval_sum, backend, X)
ERROR: TypeError: in typeassert, expected Mooncake.CoDual{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 2}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 2}}}, Tuple{Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 2}}}, Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 2}}}}}, got a value of type Mooncake.CoDual{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 2}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 2}}}, Tuple{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 2}}, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 2}}}}
Stacktrace:
  [1] #_#23
    @ ~/.julia/packages/DynamicExpressions/Kn3mj/src/Expression.jl:518 [inlined]
  [2] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [3] DerivedRule
    @ ~/PermaDocuments/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
  [4] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/PermaDocuments/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1827
  [5] LazyDerivedRule
    @ ~/PermaDocuments/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
  [6] #13
    @ ./REPL[6]:2 [inlined]
  [7] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [8] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/PermaDocuments/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
  [9] prepare_gradient_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
    @ Mooncake ~/PermaDocuments/Mooncake.jl/src/interface.jl:485
 [10] prepare_gradient_cache
    @ ~/PermaDocuments/Mooncake.jl/src/interface.jl:482 [inlined]
 [11] prepare_gradient_nokwarg(::Val{true}, ::var"#13#14"{Expression{}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/alBlj/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:114
 [12] #prepare_gradient#46
    @ ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:11 [inlined]
 [13] prepare_gradient(::var"#13#14"{Expression{}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:8
 [14] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

cc @yebai @sunxd3

@sunxd3
Collaborator
sunxd3 commented Jun 16, 2025

Thanks Miles, I'll take a look as soon as I can.

@MilesCranmer MilesCranmer changed the title from "WIP: compatibility with DynamicExpressions.jl" to "Compatibility with DynamicExpressions.jl" Jun 16, 2025
@MilesCranmer
Collaborator Author

Thanks. Actually, I think I just figured out how to fix it. It just needed the following in _rrule_getfield_common:

    fdata_for_output = if field_sym === :val
        Mooncake.fdata(pt.val)
    elseif field_sym === :children
        map(value_primal, pt.children) do child_p, child_t
            if child_t isa Mooncake.NoTangent
                Mooncake.uninit_fdata(child_p)
            else
                Mooncake.FData(Mooncake.fdata(child_t))
            end
        end
    else
        Mooncake.NoFData()
    end

So should be ready for a proper review now!

@MilesCranmer
Collaborator Author

I tried to integrate test_rule but am getting a few remaining stack overflows... I might need some help.

codecov bot commented Jun 16, 2025

Codecov Report

Attention: Patch coverage is 84.23913% with 58 lines in your changes missing coverage. Please review.

Files with missing lines              Patch %   Lines
ext/MooncakeDynamicExpressionsExt.jl  84.50%    55 Missing ⚠️
src/rrules/new.jl                     76.92%    3 Missing ⚠️


@MilesCranmer
Collaborator Author

I am deeply confused by the error:

 caused by: StackOverflowError:
  Stacktrace:
        [1] macro expansion
          @ ~/work/Mooncake.jl/Mooncake.jl/src/tangents.jl:435 [inlined]
        [2] macro expansion
          @ ./none:0 [inlined]
        [3] tangent_type(::Type{Node{Float64, 3}})
          @ Mooncake ./none:0
  --- the above 3 lines are repeated 1 more time ---
        [7] macro expansion
          @ ~/work/Mooncake.jl/Mooncake.jl/src/tangents.jl:388 [inlined]
        [8] macro expansion
          @ ./none:0 [inlined]
        [9] tangent_type(::Type{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}}})
          @ Mooncake ./none:0
  --- the above 9 lines are repeated 26660 more times ---
   [239950] macro expansion
          @ ~/work/Mooncake.jl/Mooncake.jl/src/tangents.jl:435 [inlined]
   [239951] macro expansion
          @ ./none:0 [inlined]

Is this an internal stack overflow? The tangent_type call itself should not hit a stack overflow anymore... so I am very confused by this.
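
For reference, a schematic reading of the repeated frames above, given the v2 Node layout (not Mooncake's actual code):

# tangent_type(Node{Float64,3})
#   -> tangent_type(Tuple{Nullable{Node{Float64,3}}, ...})  # the `children` field
#   -> tangent_type(Nullable{Node{Float64,3}})              # each tuple element
#   -> tangent_type(Node{Float64,3})                        # back where we started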

@MilesCranmer
Collaborator Author
MilesCranmer commented Jun 16, 2025

MWE

using Mooncake, DynamicExpressions, DifferentiationInterface

operators = OperatorEnum(1 => (cos, sin, exp, log, abs), 2 => (+, -, *, /), 3 => (fma, max))

# Define components of expression
x1 = Expression(Node{Float64,3}(; feature=1); operators)

# Build expression object
expr = max(x1, x1, x1)

make_eval_sum(expr) = X -> sum(expr(X))  # (To avoid full recompilation)

backend = AutoMooncake(; config=nothing)
gradient(make_eval_sum(expr), backend, randn(1, 20))

which gets an unexpected stack overflow from the macro expansion.

@yebai yebai requested a review from sunxd3 June 19, 2025 20:57
@sunxd3
Collaborator
sunxd3 commented Jun 20, 2025

Having spent the last couple of days looking into this, I haven't fully resolved the issue, but I have a lead.

The method at https://github.com/MilesCranmer/Mooncake.jl/blob/7f7f7ff3106759394454a1866c719d8fdf4f19e1/ext/MooncakeDynamicExpressionsExt.jl#L41-L44 wasn't being called because the type wasn't matched correctly.

When I changed the function to

function Mooncake.tangent_type(::Type{T}) where {T<:Nullable{<:AbstractExpressionNode}}
    N = T.parameters[1]
    @assert N <: AbstractExpressionNode
    D = N.parameters[2]
    Tparam = N.parameters[1]
    Tv = Mooncake.tangent_type(Tparam)
    return Union{@NamedTuple{null::NoTangent, x::TangentNode{Tv,D}},NoTangent}
end

the stack overflow went away.


The error becomes

Error message:
ERROR: MethodError: no method matching Union{…}(::@NamedTuple{…})
The type `Union{Mooncake.NoTangent, @NamedTuple{null::Mooncake.NoTangent, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}` exists, but no method is defined for this combination of argument types when trying to construct it.
Stacktrace:
  [1] macro expansion
    @ ~/TuringLang/Mooncake.jl/src/tangents.jl:549 [inlined]
  [2] macro expansion
    @ ./none:0 [inlined]
  [3] zero_tangent_internal(x::DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, d::IdDict{Any, Any})
    @ Mooncake ./none:0
  [4] zero_tangent(x::DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/tangents.jl:464
  [5] uninit_tangent
    @ ~/TuringLang/Mooncake.jl/src/tangents.jl:561 [inlined]
  [6] uninit_fdata(p::DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/fwds_rvs_data.jl:275
  [7] (::MooncakeDynamicExpressionsExt.var"#25#26")(child_p::DynamicExpressions.UtilsModule.Nullable{…}, child_t::Mooncake.NoTangent)
    @ MooncakeDynamicExpressionsExt ~/TuringLang/Mooncake.jl/ext/MooncakeDynamicExpressionsExt.jl:370
  [8] map
    @ ./tuple.jl:383 [inlined]
  [9] map
    @ ./tuple.jl:386 [inlined]
 [10] _rrule_getfield_common
    @ ~/TuringLang/Mooncake.jl/ext/MooncakeDynamicExpressionsExt.jl:368 [inlined]
 [11] rrule!!
    @ ~/TuringLang/Mooncake.jl/ext/MooncakeDynamicExpressionsExt.jl:391 [inlined]
 [12] call_mapreducer
    @ ~/.julia/packages/DynamicExpressions/Kn3mj/src/base.jl:123 [inlined]
 [13] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [14] DerivedRule
    @ ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
 [15] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1827
 [16] LazyDerivedRule
    @ ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
 [17] AbstractExpression
    @ ~/.julia/packages/DynamicExpressions/Kn3mj/src/Expression.jl:515 [inlined]
 [18] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [19] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
 [20] (::Mooncake.DynamicDerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1739
 [21] (::Mooncake.RRuleZeroWrapper{…})(f::Mooncake.CoDual{…}, args::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:302
 [22] #11
    @ ~/TuringLang/Mooncake.jl/test_pr594/recreating_error_test.jl:21 [inlined]
 [23] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [24] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
 [25] prepare_gradient_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interface.jl:485
 [26] prepare_gradient_cache
    @ ~/TuringLang/Mooncake.jl/src/interface.jl:482 [inlined]
 [27] prepare_gradient_nokwarg(::Val{true}, ::var"#11#12", ::AutoMooncake{Nothing}, ::Matrix{Float64})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/alBlj/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:114
 [28] gradient(::var"#11#12", ::AutoMooncake{Nothing}, ::Matrix{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:62
 [29] top-level scope
    @ ~/TuringLang/Mooncake.jl/test_pr594/recreating_error_test.jl:22
Some type information was truncated. Use `show(err)` to see complete types.

So I added

function Mooncake.zero_tangent_internal(
    x::Nullable{N}, dict::Mooncake.MaybeCache
) where {N<:AbstractExpressionNode}
    if x.null
        return NoTangent()
    else
        node_tangent = Mooncake.zero_tangent_internal(x.x, dict)
        return (; null=NoTangent(), x=node_tangent)
    end
end

Then the error becomes

Error message:
ERROR: TypeError: in typeassert, expected Mooncake.CoDual{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}}, Union{Mooncake.NoFData, Tuple{Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}}}}, got a value of type Mooncake.CoDual{Tuple{DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}, DynamicExpressions.UtilsModule.Nullable{Node{Float64, 3}}}, Tuple{Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}, Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::MooncakeDynamicExpressionsExt.TangentNode{Float64, 3}}}}}
Stacktrace:
  [1] call_mapreducer
    @ ~/.julia/packages/DynamicExpressions/Kn3mj/src/base.jl:123 [inlined]
  [2] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [3] DerivedRule
    @ ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
  [4] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1827
  [5] LazyDerivedRule
    @ ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
  [6] AbstractExpression
    @ ~/.julia/packages/DynamicExpressions/Kn3mj/src/Expression.jl:515 [inlined]
  [7] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [8] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
  [9] (::Mooncake.DynamicDerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1739
 [10] (::Mooncake.RRuleZeroWrapper{…})(f::Mooncake.CoDual{…}, args::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:302
 [11] #13
    @ ~/TuringLang/Mooncake.jl/test_pr594/recreating_error_test.jl:24 [inlined]
 [12] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [13] (::Mooncake.DerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:966
 [14] prepare_gradient_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
    @ Mooncake ~/TuringLang/Mooncake.jl/src/interface.jl:485
 [15] prepare_gradient_cache
    @ ~/TuringLang/Mooncake.jl/src/interface.jl:482 [inlined]
 [16] prepare_gradient_nokwarg(::Val{true}, ::var"#13#14", ::AutoMooncake{Nothing}, ::Matrix{Float64})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/alBlj/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:114
 [17] gradient(::var"#13#14", ::AutoMooncake{Nothing}, ::Matrix{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/alBlj/src/first_order/gradient.jl:62
 [18] top-level scope
    @ ~/TuringLang/Mooncake.jl/test_pr594/recreating_error_test.jl:25
Some type information was truncated. Use `show(err)` to see complete types.

(in a more human-readable form)

# Expected type:
  Mooncake.CoDual{
      Tuple{Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}},
      Union{
          Mooncake.NoFData,
          Tuple{
              Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
              Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
              Union{Mooncake.NoFData, @NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}}
          }
      }
  }

# Actual type received:
  Mooncake.CoDual{
      Tuple{Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}, Nullable{Node{Float64,3}}},
      Tuple{
          Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
          Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}},
          Mooncake.FData{@NamedTuple{null::Mooncake.NoFData, x::TangentNode{Float64,3}}}
      }
  }

which I am not able to resolve at the moment. (I suspect that Mooncake doesn't particularly like the tangent_type being a Union type. @willtebbutt)


Interestingly, the code without my modifications would work for

using Mooncake, DynamicExpressions, DifferentiationInterface

operators = OperatorEnum(
    1 => (cos, sin, exp, log, abs), 2 => (+, -, *, /, max, min), 3 => (fma, max, min)
)
x1 = Expression(Node{Float64,3}(; feature=1); operators)
test_input = randn(1, 20)
backend = AutoMooncake(; config=nothing)
expr_binary = max(x1, x1)
f = X -> sum(expr_binary(X))
result = gradient(f, backend, test_input)

which tripped me up for a bit. I don't have a good answer as to why, but it might be a helpful piece of info.

@MilesCranmer
Collaborator Author

> Having spent the last couple of days looking into this, I haven't fully resolved the issue, but I have a lead.

Thanks so much for looking into this. It is really appreciated given that I don't have much knowledge about Mooncake's internals!

> Interestingly, the code without my modifications would work for
>
> using Mooncake, DynamicExpressions, DifferentiationInterface
>
> operators = OperatorEnum(
>     1 => (cos, sin, exp, log, abs), 2 => (+, -, *, /, max, min), 3 => (fma, max, min)
> )
> x1 = Expression(Node{Float64,3}(; feature=1); operators)
> test_input = randn(1, 20)
> backend = AutoMooncake(; config=nothing)
> expr_binary = max(x1, x1)
> f = X -> sum(expr_binary(X))
> result = gradient(f, backend, test_input)
>
> which tripped me up for a bit. I don't have a good answer as to why, but it might be a helpful piece of info.

It's very weird. It seems the stack overflow error only shows up when using a degree-3 node like fma (or the 3-argument versions of max/min). Here you are using max with only 2 arguments, so it works.

This is strange, because technically they all go through the exact same function call regardless of degree: https://github.com/SymbolicML/DynamicExpressions.jl/blob/b80be2a3158e8175752ae15972962a5d1cc8a53c/src/Evaluate.jl#L292-L303

Is there a specific branch in the tangent calculation that gets triggered for 3+ argument functions that differs from others?

@sunxd3
Collaborator
sunxd3 commented Jun 20, 2025

> Is there a specific branch in the tangent calculation that gets triggered for 3+ argument functions that differs from others?

This is the question I wish I had a good answer for. Sorry to say, I haven't fully internalized Mooncake's code, so I can't say much more off the top of my head. But it's worth looking into once I have more free time!

@MilesCranmer
Collaborator Author

I think I might (?) have just solved it. I applied this patch to DynamicExpressions.jl:

   @unstable function _get_return_type(tree, cX, operators, eval_options)
      # public Julia API version of `Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))`
-     return eltype([_eval_tree_array(tree, cX, operators, eval_options) for _ in 1:0])
+     return Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))
   end

Then the stack overflow goes away.

Basically, this _get_return_type – whose only purpose is to help the compiler with inference – only gets triggered for degree 3+. I do this because higher degrees tend to be harder for the compiler to reliably infer the precise array return type, and this seems to help.

While this does result in recursive type inference, the forward pass handles it, so I am not sure why the reverse pass cannot.

Also note that I use eltype([f() for _ in 1:0]) simply as a public-API version of Core.Compiler.return_type(f), though the two should be identical. I have no idea why switching to an explicit Core.Compiler.return_type is enough to fix this. Maybe Mooncake has special treatment of either Core.Compiler.return_type or the list comprehension? @willtebbutt @yebai
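
As a small self-contained illustration of the claimed equivalence (a toy f, not the DynamicExpressions code):

julia> f(x) = 2x;

julia> eltype([f(1.0) for _ in 1:0])  # empty comprehension: eltype falls back to inference
Float64

julia> Core.Compiler.return_type(f, Tuple{Float64})
Float64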

@MilesCranmer
Collaborator Author
MilesCranmer commented Jun 23, 2025

@yebai Here is an even more minimal MWE.

First, we can see that Node{Float64}(; feature=1) works perfectly:

julia> test_data(Random.default_rng(), Node{Float64}(; feature=1))
#= (everything passes) =#

However, if we embed this into any other struct, things start to break:

julia> struct A{T}; x::T; end;

julia> test_data(Random.default_rng(), A(Node{Float64}(; feature=1)))
Test Summary: | Pass  Total  Time
===           |   42     42  0.0s
Test Summary: | Pass  Total  Time
ifelse        |   92     92  0.0s
Test Summary: | Pass  Total  Time
sizeof        |   19     19  0.0s
Test Summary: | Pass  Total  Time
isa           |   63     63  0.0s
Test Summary: | Pass  Total  Time
tuple         |   63     63  0.0s
Test Summary: | Pass  Total  Time
typeassert    |   21     21  0.0s
Test Summary: | Pass  Total  Time
typeof        |   19     19  0.0s
Test Summary: | Pass  Total  Time
getfield      |   88     88  0.0s
Test Summary: | Pass  Total  Time
lgetfield     |   88     88  0.0s
StackOverflowError()

args = (A{Node{Float64, 2}}, x1): Error During Test at /Users/mcranmer/PermaDocuments/Mooncake.jl/src/test_utils.jl:1263
  Got exception outside of a @test
  ArgumentError: rule for Mooncake.CoDual{typeof(Mooncake._new_), Mooncake.NoFData} with argument types Tuple{Mooncake.CoDual{Type{A{Node{Float64, 2}}}, Mooncake.NoFData}, Mooncake.CoDual{Node{Float64, 2}, MooncakeDynamicExpressionsExt.TangentNode{Float64, 2}}} does not run.
  Stacktrace:
    [1] test_rrule_interface(::Any, ::Any, ::Vararg{Any}; rule::Function)
      @ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:532
    [2] test_rule(::TaskLocalRNG, ::typeof(Mooncake._new_), ::Vararg{Any}; interface_only::Bool, is_primitive::Bool, perf_flag::Symbol, interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, debug_mode::Bool, unsafe_perturb::Bool)
      @ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:711
    [3] macro expansion
      @ ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1264 [inlined]
    [4] macro expansion
      @ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1793 [inlined]
    [5] macro expansion
      @ ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1263 [inlined]
    [6] macro expansion
      @ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1793 [inlined]
    [7] test_rule_and_type_interactions(rng::AbstractRNG, p::A{Node{Float64, 2}})
      @ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1261
    [8] #test_data#81
      @ ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1370 [inlined]
    [9] test_data(rng::TaskLocalRNG, p::A{Node{Float64, 2}})
      @ Mooncake.TestUtils ~/PermaDocuments/Mooncake.jl/src/test_utils.jl:1367
   [10] top-level scope
      @ REPL[35]:1
   [11] eval
      @ ./boot.jl:430 [inlined]
   [12] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:261
   [13] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:368
   [14] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:343
   [15] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:500
   [16] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/REPL/src/REPL.jl:486
   [17] (::Base.var"#1150#1152"{Bool, Symbol, Bool})(REPL::Module)
      @ Base ./client.jl:446
   [18] #invokelatest#2
      @ ./essentials.jl:1055 [inlined]
   [19] invokelatest
      @ ./essentials.jl:1052 [inlined]
   [20] run_main_repl(interactive::Bool, quiet::Bool, banner::Symbol, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:430
   [21] repl_main
      @ ./client.jl:567 [inlined]
   [22] _start()
      @ Base ./client.jl:541
  
  caused by: StackOverflowError:
  Stacktrace:
        [1] tangent_type(::Type{Nullable{Node{Float64, 2}}})
          @ Mooncake ./none:0
        [2] macro expansion
          @ ~/PermaDocuments/Mooncake.jl/src/tangents.jl:388 [inlined]
        [3] macro expansion
          @ ./none:0 [inlined]
        [4] tangent_type(::Type{Tuple{Nullable{Node{Float64, 2}}, Nullable{Node{Float64, 2}}}})
          @ Mooncake ./none:0
        [5] macro expansion
          @ ~/PermaDocuments/Mooncake.jl/src/tangents.jl:435 [inlined]
  --- the above 3 lines are repeated 1 more time ---
        [9] macro expansion
          @ ./none:0 [inlined]
  --- the above 9 lines are repeated 26660 more times ---
Test Summary:                      | Pass  Error  Total  Time
_new_                              |    4      1      5  3.0s
  args = (A{Node{Float64, 2}}, x1) |    4      1      5  3.0s
ERROR: Some tests did not pass: 4 passed, 0 failed, 1 errored, 0 broken.

Even though A{T} itself is not recursive, it seems to break _new_ when its own type parameter references a recursive struct... which I guess explains why wrapping a node in a Tuple breaks things.

@yebai
Member
yebai commented Jun 23, 2025

I think there is a world age issue with this generic _new_ implementation for types like A{T}. The world age issue means that functions defined after this _new_ function is defined are invisible to it. Thus, when it tries to build tangent types for field types, it falls back to Mooncake's generic tangent_type:

Mooncake.jl/src/utils.jl

Lines 260 to 262 in ddef072

@inline @generated function _new_(::Type{T}, x::Vararg{Any,N}) where {T,N}
    return Expr(:new, :T, map(n -> :(x[$n]), 1:N)...)
end
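
As a toy illustration of the world-age pitfall (not Mooncake code): methods added after a generated function is defined are invisible when its body runs, so dispatch inside the body falls back to the older, more generic method:

toy_tangent(::Type) = Any  # generic fallback, defined before the generated function

@generated function toy_tangent_via_gen(::Type{T}) where {T}
    # This body runs in the world age of `toy_tangent_via_gen`'s definition,
    # so any `toy_tangent` method added later (e.g. by a package extension)
    # is not visible here; the generic fallback above is what gets called.
    return :($(toy_tangent(T)))
end

toy_tangent(::Type{Float64}) = Float64  # added later: invisible to the generator

toy_tangent_via_gen(Float64)  # returns Any, not Float64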

This generated function _new_ needs a similar trick to the one used in

Mooncake.jl/src/tangents.jl

Lines 364 to 408 in ddef072

@foldable @generated function tangent_type(::Type{P}) where {N,P<:Tuple{Vararg{Any,N}}}
    # As with other types, tangent type of Union is Union of tangent types.
    P isa Union && return :(Union{tangent_type($(P.a)),tangent_type($(P.b))})

    # Determine whether P is a Tuple with a Vararg, e.g. Tuple, or Tuple{Float64, Vararg}.
    # Need to exclude `UnionAll`s from this, by checking `isa(P, DataType)`, in order to
    # ensure that `Base.datatype_fieldcount(P)` will run successfully.
    isa(P, DataType) && !(@isdefined(N)) && return Any

    # Tuple{} can only have `NoTangent` as its tangent type. As before, verify we don't have
    # a UnionAll before running to ensure that datatype_fieldcount will run.
    isa(P, DataType) && N == 0 && return NoTangent

    # Expression to construct `Tuple` type containing tangent type for all fields.
    tangent_type_exprs = map(n -> :(tangent_type(fieldtype(P, $n))), 1:N)
    tangent_types = Expr(:call, tuple, tangent_type_exprs...)

    # Construct a Tuple type of the same length as `P`, containing all `NoTangent`s.
    T_all_notangent = Tuple{Vararg{NoTangent,N}}

    return quote
        # Get tangent types for all fields. If they're all `NoTangent`, return `NoTangent`.
        # i.e. if `P = Tuple{Int, Int}`, do not return `Tuple{NoTangent, NoTangent}`.
        # Simplify and return `NoTangent`.
        tangent_types = $tangent_types
        T = Tuple{tangent_types...}
        T <: $T_all_notangent && return NoTangent

        # If exactly one of the field types is a Union, then split.
        union_fields = _findall(Base.Fix2(isa, Union), tangent_types)
        if length(union_fields) == 1 && all(tuple_map(isconcrete_or_union, tangent_types))
            return split_union_tuple_type(tangent_types)
        end

        # If it's _possible_ for a subtype of `P` to have tangent type `NoTangent`, then we
        # must account for that by returning the union of `NoTangent` and `T`. For example,
        # if `P = Tuple{Any, Int}`, then `P2 = Tuple{Int, Int}` is a subtype. Since `P2` has
        # tangent type `NoTangent`, it must be true that `NoTangent <: tangent_type(P)`. If,
        # on the other hand, it's not possible for `NoTangent` to be the tangent type, e.g.
        # for `Tuple{Float64, Any}`, then there's no need to take the union.
        return $T_all_notangent <: T ? Union{T,NoTangent} : T
    end
end

I might not have enough time today. Do you want to give it a try?

@MilesCranmer
Collaborator Author

I think part of the issue was this: #606. That might help explain why I was getting confusing stack traces this whole time.

@MilesCranmer
Collaborator Author

OK, with #606 (cherry-picked here) and a few missing methods implemented, this all seems to be working! Ready for a new review.

Collaborator
@sunxd3 sunxd3 left a comment


looks good to me.

bump Mooncake version?

@MilesCranmer
Collaborator Author

Done

@sunxd3 sunxd3 requested a review from yebai June 24, 2025 10:26
@yebai
Member
yebai commented Jun 24, 2025

@MilesCranmer, feel free to merge if CI passes.

@MilesCranmer
Collaborator Author

Note that this also includes the commits in #606. Do you want to merge that separately or with this?

@yebai
Member
yebai commented Jun 24, 2025

> Note that this also includes the commits in #606. Do you want to merge that separately or with this?

It's okay to merge this; @sunxd3 will add a MWE-based regression test later.

@MilesCranmer
Collaborator Author

Squash or regular merge? It seems my account can only create squash merges.

@yebai
Member
yebai commented Jun 24, 2025

Squash is preferred to keep the history simple.

@MilesCranmer MilesCranmer merged commit bf7aff9 into chalk-lab:main Jun 24, 2025
79 of 80 checks passed