
ChainRulesCore.ProjectTo creates sparse matrices of the wrong element type (drops Duals) #648

Open
ChrisRackauckas opened this issue Dec 30, 2023 · 3 comments
Labels
bug Something isn't working help wanted Extra attention is needed


@ChrisRackauckas
Member

ChrisRackauckas commented Dec 30, 2023

MWE:

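```julia
using Zygote, SparseArrays, ForwardDiff

x, v = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))

# Tag type for the Duals
T = typeof(ForwardDiff.Tag(nothing, eltype(x)))
# Duals with value x and a single partial carrying the direction v
y = ForwardDiff.Dual{T, eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
g = x -> first(Zygote.gradient(loss, x))
# Forward-over-reverse: the partials of the gradient should be H*v; this throws a MethodError
ForwardDiff.partials.(g(y), 1)
```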
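For reference, `ForwardDiff.partials.(g(y), 1)` above is meant to be the Hessian-vector product H*v of `loss` at `x`, computed forward-over-reverse (each `Dual` carries `v` as its single partial). For this particular `loss` there is a closed form, so a sketch like the following, which uses only SparseArrays and no AD, could serve as an expected value once the projection keeps the `Dual`s. The names `grad_reference` and `hvp_reference` are hypothetical, the example uses Float64 rather than the MWE's Float32, and the finite-difference comparison is only a sanity check of the formula.

```julia
using SparseArrays

# Hypothetical reference (not part of any package): closed-form gradient and
# Hessian-vector product of loss(x) = sum(tanh.(A * x)), with z = A * x.
#   ∇loss(x) = A' * sech².(z)
#   H * v    = A' * (-2 .* tanh.(z) .* sech².(z) .* (A * v))
grad_reference(A, x) = A' * (1 .- tanh.(A * x) .^ 2)

function hvp_reference(A, x, v)
    z = A * x
    sech2 = 1 .- tanh.(z) .^ 2   # sech²(z) = 1 - tanh²(z)
    return A' * (-2 .* tanh.(z) .* sech2 .* (A * v))
end

# Small deterministic example in Float64
A = sparse([1, 2, 3, 1], [1, 2, 3, 3], [0.5, -1.0, 2.0, 0.3], 3, 3)
x = [0.1, 0.2, -0.3]
v = [1.0, 0.0, 2.0]

# Central finite difference of the gradient along v, as an independent check
h = 1e-6
fd = (grad_reference(A, x .+ h .* v) .- grad_reference(A, x .- h .* v)) ./ (2h)
hv = hvp_reference(A, x, v)
```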
@mcabbott
Member

Note that this MWE doesn't run:

```julia
julia> y = _default_autoback_hesvec_cache(x, v)
ERROR: UndefVarError: `_default_autoback_hesvec_cache` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
 [1] top-level scope
   @ REPL[199]:1
 [2] top-level scope
   @ ~/.julia/packages/Metal/qeZqc/src/initialization.jl:51

julia> ForwardDiff.partials.(g(y), 1)
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float32, 1})
```

@ChrisRackauckas
Member Author

Fixed, just delete that extra line.

@mcabbott
Member

mcabbott commented Dec 31, 2023

A less complicated way to trigger this seems to be:

```julia
julia> x |> summary  # from above
"5-element Vector{Float32}"

julia> Zygote.gradient(loss, x)
(Float32[0.711501, 0.9295027, 0.035282552, 0.9122769, 0.3412085],)

julia> Zygote.gradient(loss, x .+ Dual(0,1))
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float32, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:266
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:894
  Float32(::IrrationalConstants.Invsqrtπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  ...

Stacktrace:
  [1] convert(::Type{Float32}, x::Dual{Nothing, Float32, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float32}, x::Dual{Nothing, Float32, 1}, i::Int64)
    @ Base ./array.jl:969
  [3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, @NamedTuple{…}})(dx::Matrix{Dual{…}})
    @ ChainRulesCoreSparseArraysExt ~/.julia/packages/ChainRulesCore/7MWx2/ext/ChainRulesCoreSparseArraysExt.jl:79
  [4] #1476
    @ ~/.julia/packages/ChainRules/snrkz/src/rulesets/Base/arraymath.jl:36 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
```

An MWE that doesn't use Zygote:

```julia
julia> using ChainRulesCore, SparseArrays, ForwardDiff

julia> A = sprand(Float32, 5, 5, 0.5);

julia> ProjectTo(A)(ones(5, 5))
5×5 SparseMatrixCSC{Float32, Int64} with 14 stored entries:
  ⋅   1.0  1.0   ⋅    ⋅
  ⋅   1.0   ⋅    ⋅    ⋅
 1.0   ⋅   1.0   ⋅   1.0
 1.0   ⋅   1.0  1.0  1.0
 1.0   ⋅   1.0  1.0  1.0

julia> ProjectTo(A)(ones(5, 5) .+ ForwardDiff.Dual(0,1))
ERROR: MethodError: no method matching Float32(::Dual{Nothing, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:266
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:894
  Float32(::IrrationalConstants.Invsqrtπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
  ...

Stacktrace:
 [1] convert(::Type{Float32}, x::Dual{Nothing, Float64, 1})
   @ Base ./number.jl:7
 [2] setindex!(A::Vector{Float32}, x::Dual{Nothing, Float64, 1}, i::Int64)
   @ Base ./array.jl:969
 [3] (::ProjectTo{SparseMatrixCSC, @NamedTuple{…}})(dx::Matrix{Dual{…}})
   @ ChainRulesCoreSparseArraysExt ~/.julia/packages/ChainRulesCore/7MWx2/ext/ChainRulesCoreSparseArraysExt.jl:79
 [4] top-level scope
```

Dense matrices do allow element types like `Dual` here, e.g. `ProjectTo(Matrix(A))(ones(5, 5) .+ ForwardDiff.Dual(0,1))`. This is needed for forward-over-reverse AD, such as `Zygote.hessian`.

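The bug is in this line of `ChainRulesCoreSparseArraysExt.jl`, which allocates the stored values with the primal's element type (here `Float32`), so writing a `Dual` into that vector fails in `convert`:

```julia
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
```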
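One possible direction, sketched as a standalone function rather than a patch to ChainRulesCore: build `nzval` from the projected values so that its element type follows them, instead of preallocating it with `project_type(project.element)`. `project_onto_pattern` and `narrow` are hypothetical names; `narrow` is only an assumed approximation of the per-element projector (narrow other floats to `Float32`, pass any other number through), and `Rational` stands in for `Dual` so the example needs nothing beyond SparseArrays.

```julia
using SparseArrays

# Hypothetical per-element projector, standing in for `project.element` (assumption):
# narrow other floats to Float32, pass any other number type through unchanged.
# A Dual would take the second method; in this example Rational plays that role.
narrow(x::AbstractFloat) = Float32(x)
narrow(x::Number) = x

# Hypothetical helper, not ChainRulesCore's API: keep the entries of `dx` that lie
# in the sparsity pattern of `A`. The element type of `nzval` comes from the
# projected values themselves, not fixed in advance to eltype(A).
function project_onto_pattern(A::SparseMatrixCSC, dx::AbstractMatrix; elproj = narrow)
    size(dx) == size(A) || throw(DimensionMismatch("expected size $(size(A)), got $(size(dx))"))
    rows = rowvals(A)
    # Outer loop over columns, inner over stored entries: same order as A.rowval
    nzval = [elproj(dx[rows[k], col]) for col in 1:size(A, 2) for k in nzrange(A, col)]
    return SparseMatrixCSC(size(A, 1), size(A, 2), copy(A.colptr), copy(rows), nzval)
end

A = sparse([1, 3, 2], [1, 1, 3], Float32[1, 2, 3], 3, 3)
P1 = project_onto_pattern(A, ones(3, 3))        # Float64 input, narrowed to Float32
P2 = project_onto_pattern(A, fill(1//2, 3, 3))  # non-float Real kept as Rational
```

The comprehension takes its element type from the values it produces, so nothing is converted to `eltype(A)`; whether the real element projector passes a `Dual` through unchanged is an assumption of this sketch.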

Note that the whole sparse projection story is basically placeholder code, rushed in to make 1.0 have the desired behaviour of preserving sparsity. It is quite slow, and could really use some care from someone who actually uses it. (Maybe shipping 1.0 with deliberate errors would have been better.)

@mcabbott mcabbott transferred this issue from JuliaDiff/ChainRules.jl Dec 31, 2023
@mcabbott mcabbott added bug Something isn't working help wanted Extra attention is needed labels Dec 31, 2023