writing rules for <:AbstractArray #582

Open
maartenvd opened this issue Sep 2, 2022 · 1 comment
Labels
documentation Improvements or additions to documentation ProjectTo related to the projection functionality Structural Tangent Related to the `Tangent` type for structured (composite) values

Comments

@maartenvd

How should one write "proper" rules for methods that work for generic AbstractArray objects?

As an example, take this function:

function _setindex(a::AbstractArray,v,args...)
    b::typeof(a) = copy(a);
    b[args...] = v
    b
end

This method seems pretty tame, and I think it should be generically correct for any AbstractArray object. The backward rule looks simple:

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        backwards_tv = v[args...];
        backwards_a = copy(v);
        backwards_a[args...] = zero.(backwards_a[args...])
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

This doesn't work, of course: v can be a ZeroTangent! Let's correct for this case:

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent();
            backwards_a = ZeroTangent();
        else
            backwards_tv = v[args...];
            backwards_a = copy(v);
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

But this rule is still incorrect! When working with arrays, the tangent can sometimes be a Fill (from FillArrays). Fill arrays don't define setindex!, but they can be converted.

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent();
            backwards_a = ZeroTangent();
        else
            v = convert(typeof(a),v);
            backwards_tv = v[args...];
            backwards_a = copy(v);
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

Still wrong, of course, as it can also be a Tangent, which cannot be copied or converted but can be constructed!

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent();
            backwards_a = ZeroTangent();
        else
            v = v isa Tangent ? construct(typeof(a),v) : v;
            v = convert(typeof(a),v);
            backwards_tv = v[args...];
            backwards_a = copy(v);
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

In short, my rrule essentially has to be a spaghetti of if statements, and at the end I will have no way of knowing whether my implementation will work in practice. There is no list of possible tangent types, nor a formal interface that they should all satisfy, and so whatever operations I do may end up being undefined.

I have read the documentation, and I just don't understand how I am to write this backward rule. I also don't understand how I am to hook up my own types so that they play nice with chainrules.

This year-old PR, #446, seems like a step in the right direction, but even that wouldn't solve the issue completely. ProjectTo is defined in such a way that, when faced with a type it doesn't know, it falls back to just returning the same Tangent type.

@mcabbott
Member

mcabbott commented Sep 3, 2022

> should be generically correct for any AbstractArray object

I think it needs to ensure a mutable copy, more like b = copyto!(similar(a, T, axes(a)), a). This is poorly documented, but the version of similar that takes axes removes a lot of structure from e.g. UpperTriangular.

This will also solve the Fill problem.
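A minimal sketch of the mutable-copy idea, using only Base and LinearAlgebra. The helper name `mutable_copy` and the default element type `eltype(a)` are assumptions made for illustration; the comment above leaves `T` unspecified.

```julia
using LinearAlgebra

# Hypothetical helper (not part of any package): returns a plain, writable
# array with the same axes and contents as `a`.
mutable_copy(a::AbstractArray, T::Type=eltype(a)) =
    copyto!(similar(a, T, axes(a)), a)

# A range is an immutable AbstractArray: it has no setindex!.
r = 1:3
b = mutable_copy(r)      # a Vector{Int}
b[2] = 10

# UpperTriangular refuses writes below the diagonal; the copy is a dense Matrix.
U = UpperTriangular([1.0 2.0; 3.0 4.0])
M = mutable_copy(U)      # a Matrix{Float64}, structural zero included
M[2, 1] = 5.0
```

The price is that the result no longer has the type of `a`, so the `b::typeof(a)` annotation from the original function would have to go.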

> v can be a ZeroTangent!

Or NoTangent. I actually thought this was supposed to be handled upstream (i.e. AD should notice and not call the pullback) but many rules have ended up with toret(v::AbstractZero) = ....
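One way this might look, sketched under assumptions: a dedicated pullback method for AbstractZero, plus the copyto!(similar(...)) copy suggested earlier in this comment. The name `setindex_pullback`, the use of `eltype(a)` for `T`, and returning NoTangent() for the indices are choices made here for illustration, not something this thread prescribes. Tangent cotangents are deliberately not handled.

```julia
using ChainRulesCore

# Forward function from the issue, with a copy that is always writable.
function _setindex(a::AbstractArray, v, args...)
    b = copyto!(similar(a, eltype(a), axes(a)), a)
    b[args...] = v
    return b
end

function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    # Assumed convention: indices are not differentiable.
    index_tangents = map(_ -> NoTangent(), args)

    # Zero cotangent (ZeroTangent or NoTangent): pass it through untouched.
    setindex_pullback(v::AbstractZero) = (NoTangent(), v, v, index_tangents...)

    # Array cotangent: dense, Fill-like, or a thunk wrapping one.
    function setindex_pullback(vbar)
        v = unthunk(vbar)
        backwards_tv = v[args...]
        # Writable copy, so immutable cotangent arrays are fine too.
        backwards_a = copyto!(similar(v, eltype(v), axes(v)), v)
        backwards_a[args...] = zero(backwards_tv)
        return (NoTangent(), backwards_a, backwards_tv, index_tangents...)
    end

    return t, setindex_pullback
end

# Usage: overwrite element 2, then pull back a cotangent of ones.
a = [1.0, 2.0, 3.0]
y, back = rrule(_setindex, a, 10.0, 2)
grads = back(ones(3))
zero_grads = back(ZeroTangent())
```

Dispatching on `AbstractZero` replaces the `iszero(v)` branch, and the writable copy replaces the `convert(typeof(a), v)` step.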

> it can also be a Tangent,

Yes, a generic rule targeting AbstractArray can't really know what to do with a Tangent. But notice this: for it to get (say) a Tangent{RotationMatrix}(; θ = ...) back, the forward pass of this generic rule has to produce a RotationMatrix. It may get, say, a = RotationMatrix(pi/4), but to perform, say, _setindex(a, [10,20], :, 1) (writing numbers to one column) we cannot fit b into the same structure.

For any given weird array type, there may be some generic rules which can produce it. For instance the rule for * may produce another RotationMatrix. (Or a Tridiagonal.) It's these rules where we have to worry about getting back a Tangent, and #446 wants to standardise many back to AbstractMatrix.
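For reference, this is roughly how the existing ProjectTo behaves for a structured type it already knows (Diagonal here): it maps a generic matrix cotangent back onto the structure of the primal. This is only a sketch of current behaviour for a supported type; it says nothing about what #446 would do with a Tangent.

```julia
using ChainRulesCore, LinearAlgebra

D = Diagonal([1.0, 2.0])
project = ProjectTo(D)

# A dense cotangent is projected back onto the Diagonal structure.
dD = project(ones(2, 2))

# Zero cotangents pass through unchanged.
dz = project(ZeroTangent())
```

A rule that can return a structured array would typically build `ProjectTo(a)` in the forward pass and apply it to the cotangent of `a` in the pullback.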

The alternative is for a type to standardise on a Tangent. In which case it probably wants to use @opt_out on generic rules which may produce it, so that AD deals with its own implementation of say * which takes the type apart -- uses getfield etc, and these pieces are what the Tangent contains.
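A sketch of what opting out might look like. The `RotationMatrix` definition below is hypothetical, reconstructed from the `RotationMatrix(pi/4)` and `Tangent{RotationMatrix}(; θ = ...)` examples above; only `@opt_out`, `rrule` and `Tangent` are ChainRulesCore names.

```julia
using ChainRulesCore

# Hypothetical structured array: a 2×2 rotation stored as a single angle.
struct RotationMatrix{T<:Real} <: AbstractMatrix{T}
    θ::T
end

Base.size(::RotationMatrix) = (2, 2)

function Base.getindex(R::RotationMatrix, i::Int, j::Int)
    @boundscheck checkbounds(R, i, j)
    s, c = sincos(R.θ)
    # [c -s; s c]
    return i == j ? c : (i == 1 ? -s : s)
end

# The type's own multiplication: composing rotations adds the angles.
Base.:*(A::RotationMatrix, B::RotationMatrix) = RotationMatrix(A.θ + B.θ)

# Opt out of any generic rrule for * on this type, so that AD differentiates
# the method above (and its field access) instead.
@opt_out ChainRulesCore.rrule(::typeof(*), ::RotationMatrix, ::RotationMatrix)

R = RotationMatrix(pi / 4)

# With the opt-out in place, the rule lookup reports "no rule".
no_rule = ChainRulesCore.rrule(*, R, R)

# The structural tangent that AD would then carry for this type.
dR = Tangent{typeof(R)}(; θ = 1.0)
```

With only ChainRulesCore loaded there is no generic rule for `*` to begin with; the opt-out matters once ChainRules.jl (which defines such rules) is in the session.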

@mcabbott mcabbott added Structural Tangent Related to the `Tangent` type for structured (composite) values ProjectTo related to the projection functionality documentation Improvements or additions to documentation labels Sep 4, 2022