Enzyme custom rules tutorial

More Examples

The tutorial below focuses on a simple setting to illustrate the basic concepts of writing custom rules. For more complex custom rules beyond the scope of this tutorial, you may take inspiration from the following in-the-wild examples:

The goal of this tutorial is to give a simple example of defining a custom rule with Enzyme. Specifically, our goal will be to write custom rules for the following function f:

function f(y, x)
    y .= x.^2
    return sum(y)
end
f (generic function with 1 method)

Our function f populates its first input y with the element-wise square of x. In addition, it returns sum(y) as output. What a sneaky function!

In this case, Enzyme can differentiate through f automatically. For example, using forward mode:

using Enzyme
x  = [3.0, 1.0]
dx = [1.0, 0.0]
y  = [0.0, 0.0]
dy = [0.0, 0.0]

g(y, x) = f(y, x)^2 # function to differentiate

@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]

(See the AutoDiff API tutorial for more information on using autodiff.)
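The printed values can be checked by hand. In this worked derivation, added for illustration, s denotes the value returned by f and δ_ij is the Kronecker delta:

$$
\begin{aligned}
y_i &= x_i^2, \qquad s = f(y, x) = \sum_i x_i^2, \qquad g = s^2,\\
\frac{\partial y_i}{\partial x_j} &= 2 x_i\,\delta_{ij}, \qquad
\frac{\partial g}{\partial x_j} = 2 s\,\frac{\partial s}{\partial x_j} = 4 s\,x_j .
\end{aligned}
$$

At x = (3, 1) this gives s = 10 and ∂g/∂x_1 = 4 · 10 · 3 = 120. Seeding dx = (1, 0) gives the tangent of y as (2 · 3 · 1, 2 · 1 · 0) = (6, 0), matching the output above.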

But there may be special cases where we need to write a custom rule to help Enzyme out. Let's see how to write a custom rule for f!

Don't use custom rules unnecessarily!

Enzyme can efficiently handle a wide range of constructs, and so a custom rule should only be required in certain special cases. For example, a function may make a foreign call that Enzyme cannot differentiate, or we may have higher-level mathematical knowledge that enables us to write a more efficient rule. Even in these cases, try to make your custom rule encapsulate the minimum possible construct that Enzyme cannot differentiate, rather than expanding the scope of the rule unnecessarily. For pedagogical purposes, we will disregard this principle here and go ahead and write a custom rule for f :)

Defining our first rule

First, we import the functions EnzymeRules.forward, EnzymeRules.augmented_primal, and EnzymeRules.reverse. We need to overload forward in order to define a custom forward rule, and we need to overload augmented_primal and reverse in order to define a custom reverse rule.

import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

In this section, we write a simple forward rule to start out:

function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, y::Duplicated, x::Duplicated)
    println("Using custom rule!")
    ret = func.val(y.val, x.val)
    y.dval .= 2 .* x.val .* x.dval
    return Duplicated(ret, sum(y.dval))
end
forward (generic function with 9 methods)

In the signature of our rule, we have made use of Enzyme's activity annotations. Let's break down each one:

  • the Const annotation on f indicates that we accept a function f that does not have a derivative component, which makes sense since f is not a closure with data that could be differentiated.
  • the Duplicated annotation given in the second argument annotates the return value of f. This means that our forward function should return an output of type Duplicated, containing the original output sum(y) and its derivative.
  • the Duplicated annotations for x and y mean that our forward function handles inputs x and y which have been marked as Duplicated. Their shadows hold tangents: the rule reads the tangent of x from x.dval and must write the tangent of the mutated y into y.dval.
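To make the arithmetic of this rule concrete, here is a minimal plain-Julia sketch that mirrors its body without any Enzyme machinery. The helper name f_forward_sketch! is hypothetical. The final chain-rule step is our own illustration of how the tangent of sum(y) combines with the outer square in g; it is not a description of Enzyme's internals.

```julia
# Hypothetical helper mirroring the body of the first forward rule (no Enzyme needed).
function f_forward_sketch!(y, dy, x, dx)
    y .= x .^ 2             # primal: populate y, as f does
    dy .= 2 .* x .* dx      # tangent of y: d(x_i^2) = 2 x_i dx_i
    return sum(y), sum(dy)  # primal return value and its tangent
end

x, dx = [3.0, 1.0], [1.0, 0.0]   # seed the direction of x[1]
y, dy = zeros(2), zeros(2)
ret, dret = f_forward_sketch!(y, dy, x, dx)

# g(y, x) = f(y, x)^2, so by the chain rule dg = 2 * ret * dret.
dg = 2 * ret * dret
```

With these inputs, ret is 10.0, dret is 6.0, dy becomes [6.0, 0.0], and dg is 120.0, the same numbers the autodiff calls in this tutorial report.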

In the logic of our forward function, we run the original function, populate y.dval (the shadow of y), and finally return a Duplicated for the output as promised. Let's see our rule in action! With the same setup as before:

x  = [3.0, 1.0]
dx = [1.0, 0.0]
y  = [0.0, 0.0]
dy = [0.0, 0.0]

g(y, x) = f(y, x)^2 # function to differentiate

@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when g is run
Using custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
dy = [6.0, 0.0]

We see that our custom forward rule has been triggered and gives the same answer as before.

Handling more activities

Our custom rule applies to the specific set of activities that are annotated for f in the above autodiff call. However, Enzyme has a number of other annotations. Let us consider a particular example, where the output has a DuplicatedNoNeed annotation. This means we are only interested in its derivative, not its value. To squeeze out the last drop of performance, the rule below skips computing the return value sum(y) of the original function (y itself must still be updated) and returns only its derivative.

function forward(func::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, y::Duplicated, x::Duplicated)
    println("Using custom rule with DuplicatedNoNeed output.")
    y.val .= x.val.^2
    y.dval .= 2 .* x.val .* x.dval
    return sum(y.dval)
end
forward (generic function with 10 methods)

Our rule is triggered, for example, when we call autodiff directly on f, since only the derivative of the return value is needed, not the return value itself:

x  = [3.0, 1.0]
dx = [1.0, 0.0]
y  = [0.0, 0.0]
dy = [0.0, 0.0]

@show autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) # derivative of f w.r.t. x[1]
@show dy; # derivative of y w.r.t. x[1] when f is run
Using custom rule with DuplicatedNoNeed output.
autodiff(Forward, f, Duplicated(y, dy), Duplicated(x, dx)) = (6.0,)
dy = [6.0, 0.0]

Custom rule dispatch

When multiple custom rules for a function are defined, the correct rule is chosen using Julia's multiple dispatch. In particular, it is important to understand that the custom rule does not determine the activities of the inputs and the return value: rather, Enzyme decides the activity annotations independently, and then dispatches to the custom rule that handles those activities, if one exists. If a custom rule is specified for the correct function/argument types, but not the correct activity annotation, a runtime error will be thrown alerting the user to the missing activity rule rather than silently ignoring the rule.
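The dispatch mechanism itself is ordinary Julia. The sketch below uses hypothetical stand-in types (MyConst, MyDuplicated, MyDuplicatedNoNeed), not Enzyme's own, to show how methods that differ only in the return-activity type argument are selected, and that a missing combination has no applicable method. In plain Julia such a call would throw a MethodError; the error Enzyme itself reports may differ.

```julia
# Hypothetical stand-ins for activity annotations; these are NOT Enzyme's types.
struct MyConst{T}
    val::T
end
struct MyDuplicated{T}
    val::T
    dval::T
end
struct MyDuplicatedNoNeed{T}
    val::T
    dval::T
end

# Two "rules" that differ only in the return-activity type they accept,
# written in the same style as forward(func, ::Type{<:Duplicated}, ...).
which_rule(::Type{<:MyDuplicated}, x::MyDuplicated) = :duplicated_rule
which_rule(::Type{<:MyDuplicatedNoNeed}, x::MyDuplicated) = :noneed_rule

x = MyDuplicated([3.0, 1.0], [1.0, 0.0])

# The caller chooses the activity; dispatch then picks the matching method.
r1 = which_rule(MyDuplicated{Float64}, x)         # :duplicated_rule
r2 = which_rule(MyDuplicatedNoNeed{Float64}, x)   # :noneed_rule

# No method accepts a Const-like return activity, so calling it would throw a MethodError.
has_const_rule = applicable(which_rule, MyConst{Float64}, x)   # false
```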

Finally, it may be that any of x, y, or the return value is marked as Const. We can in fact handle this case, along with the previous two cases, all together in a single rule:

Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules

function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
                 y::Union{Const, Duplicated}, x::Union{Const, Duplicated})
    println("Using our general custom rule!")
    y.val .= x.val.^2
    if !(x isa Const) && !(y isa Const)
        y.dval .= 2 .* x.val .* x.dval
    elseif !(y isa Const)
        y.dval .= 0
    end
    dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
    if RT <: Const
        return sum(y.val)
    elseif RT <: DuplicatedNoNeed
        return dret
    else
        return Duplicated(sum(y.val), dret)
    end
end
forward (generic function with 9 methods)

Let's try out our rule:

x  = [3.0, 1.0]
dx = [1.0, 0.0]
y  = [0.0, 0.0]
dy = [0.0, 0.0]

g(y, x) = f(y, x)^2 # function to differentiate

@show autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) # derivative of g w.r.t. x[1]
@show autodiff(Forward, g, Const(y), Duplicated(x, dx)) # derivative of g w.r.t. x[1], with y annotated Const
@show autodiff(Forward, g, Const(y), Const(x)); # derivative of g w.r.t. x[1], with x and y annotated Const
Using our general custom rule!
autodiff(Forward, g, Duplicated(y, dy), Duplicated(x, dx)) = (120.0,)
Using our general custom rule!
autodiff(Forward, g, Const(y), Duplicated(x, dx)) = (0.0,)
autodiff(Forward, g, Const(y), Const(x)) = (0.0,)
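The branching in the general rule can be traced in plain Julia. In this hypothetical sketch, a Symbol stands in for the return activity and passing nothing as a shadow plays the role of a Const annotation; neither convention is part of Enzyme's API.

```julia
# Hypothetical plain-Julia mirror of the general forward rule's branching.
# `rt` is a Symbol standing in for the return activity, and passing `nothing`
# as a shadow plays the role of a Const annotation on that argument.
function general_forward_sketch!(rt::Symbol, y, dy, x, dx)
    y .= x .^ 2                        # y is mutated in every case
    if dx !== nothing && dy !== nothing
        dy .= 2 .* x .* dx             # x and y active: propagate the tangent
    elseif dy !== nothing
        dy .= 0                        # x Const: y's new value has zero tangent
    end
    # With y Const there is no shadow to sum, so the rule reports zero.
    dret = dy !== nothing ? sum(dy) : zero(eltype(y))
    rt === :Const && return sum(y)
    rt === :DuplicatedNoNeed && return dret
    return (sum(y), dret)              # :Duplicated returns value and tangent
end

x, dx = [3.0, 1.0], [1.0, 0.0]
both_active = general_forward_sketch!(:Duplicated, zeros(2), zeros(2), x, dx)        # (10.0, 6.0)
y_const     = general_forward_sketch!(:Duplicated, zeros(2), nothing, x, dx)         # (10.0, 0.0)
noneed      = general_forward_sketch!(:DuplicatedNoNeed, zeros(2), zeros(2), x, dx)  # 6.0
```

Squaring through g then gives 2 · 10 · 6 = 120 when both arguments are active and 2 · 10 · 0 = 0 when y is Const, matching the printed results.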

Note that there also exist batched duplicated annotations for forward mode, namely BatchDuplicated and BatchDuplicatedNoNeed, which are not covered in this tutorial.

Defining a reverse-mode rule

Let's look at how to write a simple reverse-mode rule! First, we write a method for EnzymeRules.augmented_primal:

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active},
                          y::Duplicated, x::Duplicated)
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(y.val, x.val)
    else
        y.val .= x.val.^2 # y still needs to be mutated even if primal not needed!
        primal = nothing
    end
    # Save x in tape if x will be overwritten
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end
augmented_primal (generic function with 9 methods)

Let's unpack our signature for augmented_primal:

  • We accepted an EnzymeRules.Config object with a specified width of 1, which means that our rule does not support batched reverse mode.
  • We annotated f with Const as usual.
  • We dispatched on an Active annotation for the return value. This is a special annotation for scalar values, such as our return value, that indicates that we care about the value's derivative but need not explicitly allocate a mutable shadow, since it is a scalar.
  • We annotated x and y with Duplicated, similar to our first simple forward rule.

Now, let's unpack the body of our augmented_primal rule:

  • We checked if the config requires the primal. If not, we need not compute the return value, but we make sure to mutate y in all cases.
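  • We checked whether x could possibly be overwritten using overwritten(config), which reads the Overwritten attribute of EnzymeRules.Config; if so, we save a copy of x on the tape of the returned EnzymeRules.AugmentedReturn object. The entries of overwritten(config) follow the argument order (f, y, x), so index 3 refers to x.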
  • We return a shadow of nothing since the return value is Active and hence does not need a shadow.
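Why the tape matters can be shown without Enzyme: if code that runs after f overwrites x (as h does later in this tutorial), a reverse pass that rereads x sees the new values. The following plain-Julia sketch, with variable names of our own choosing, computes the reverse-pass contribution of ret = sum(x.^2) with dret = 1 once from a saved copy and once from the overwritten array.

```julia
x = [3.0, 1.0]
y = zeros(2)

# Augmented primal pass of f: populate y, compute ret, and save x on a tape.
y .= x .^ 2
ret = sum(y)
tape = copy(x)             # a copy, so later mutation of x cannot affect it

# Code that runs after f overwrites x in place, as h does in this tutorial.
x .= x .^ 2                # x is now [9.0, 1.0]

# Reverse pass for ret = sum(x .^ 2), with incoming derivative dret = 1.0.
dret = 1.0
grad_from_tape = 2 .* tape .* dret   # [6.0, 2.0]: uses x as it was when f ran
grad_from_x    = 2 .* x .* dret      # [18.0, 2.0]: wrong, uses the overwritten x
```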

Now, we write a method for EnzymeRules.reverse:

function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, tape,
                 y::Duplicated, x::Duplicated)
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    # accumulate dret into x's shadow. don't assign!
    x.dval .+= 2 .* xval .* dret.val
    # also accumulate any derivative in y's shadow into x's shadow.
    x.dval .+= 2 .* xval .* y.dval
    y.dval .= 0
    return (nothing, nothing)
end
reverse (generic function with 9 methods)

Let's make a few observations about our reverse rule:

  • The activities used in the signature correspond to what we used for augmented_primal.
  • However, for Active return types such as in this case, we now receive an instance dret of Active for the return type, not just a type annotation. It stores the derivative with respect to the return value (not the original return value!). For the other annotations (e.g. Duplicated), we still receive only the type. In that case, if necessary, a reference to the shadow of the output should be placed on the tape in augmented_primal.
  • Using dret.val and y.dval, we accumulate the backpropagated derivatives for x into its shadow x.dval. Note that we have to accumulate from both y.dval and dret.val. This is because in reverse-mode AD we have to sum up the derivatives from all uses: if y was read after our function, we need to consider derivatives from that use as well.
  • We zero out y's shadow. This is because y is overwritten within f, so there is no derivative with respect to the value of y that was originally passed in.
  • Finally, since all derivatives are accumulated in place (in the shadows of the Duplicated arguments), these derivatives must not be communicated via the return value. Hence, we return (nothing, nothing). If, instead, one of our arguments was annotated as Active, we would have to provide its derivative at the corresponding index in the tuple returned.

Finally, let's see our reverse rule in action!

x  = [3.0, 1.0]
dx = [0.0, 0.0]
y  = [0.0, 0.0]
dy = [0.0, 0.0]

g(y, x) = f(y, x)^2

autodiff(Reverse, g, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of g w.r.t. x
@show dy; # derivative of g w.r.t. y
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]

Let's also try a function that mutates x after running f and that reads y directly, rather than using only ret, after running f (mathematically it equals g, so it ultimately gives the same result as above):

function h(y, x)
    ret = f(y, x)
    x .= x.^2
    return ret * sum(y)
end

x  = [3.0, 1.0]
y  = [0.0, 0.0]
dx .= 0
dy .= 0

autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of h w.r.t. x
@show dy; # derivative of h w.r.t. y
In custom augmented primal rule.
In custom reverse rule.
dx = [120.0, 40.0]
dy = [0.0, 0.0]
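Both reverse-mode results can be reproduced by hand. This plain-Julia sketch mirrors the body of the reverse rule in a hypothetical helper f_reverse_sketch! and supplies the incoming derivatives that g and h would pass to it, which we derived ourselves: for g = ret^2 the incoming dret is 2·ret and y is not read afterwards, while for h = ret · sum(y) both dret and every entry of y's shadow equal ret.

```julia
# Hypothetical helper mirroring the body of the custom reverse rule.
function f_reverse_sketch!(dx, dy, xval, dret)
    dx .+= 2 .* xval .* dret   # contribution through the return value
    dx .+= 2 .* xval .* dy     # contribution through later uses of y
    dy .= 0                    # y was overwritten by f
    return dx
end

xval = [3.0, 1.0]              # x as it was when f ran (from x itself or from the tape)
ret = sum(xval .^ 2)           # 10.0

# g(y, x) = ret^2: dg/dret = 2 * ret, and y is not read afterwards.
dx_g, dy_g = zeros(2), zeros(2)
f_reverse_sketch!(dx_g, dy_g, xval, 2 * ret)

# h(y, x) = ret * sum(y): dh/dret = sum(y) = ret, and each y[i] receives adjoint ret.
# The later x .= x.^2 contributes nothing, because h's result never reads the new x.
dx_h, dy_h = zeros(2), fill(ret, 2)
f_reverse_sketch!(dx_h, dy_h, xval, ret)
```

Both calls leave [120.0, 40.0] in the x shadow, and for h the y shadow ends up zeroed, matching the printed dx and dy. In the h case, half of the result arrives through y's shadow, which is why the rule must accumulate from both dret.val and y.dval.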

Marking functions inactive

If we want to tell Enzyme that the function call does not affect the differentiation result in any form (i.e. not by side effects or through its return values), we can simply use EnzymeRules.inactive. So long as there exists a matching dispatch to EnzymeRules.inactive, the function will be considered inactive. For example:

printhi() = println("Hi!")
EnzymeRules.inactive(::typeof(printhi), args...) = nothing

function k(x)
    printhi()
    return x^2
end

autodiff(Forward, k, Duplicated(2.0, 1.0))
(4.0,)

Or for a case where we incorrectly mark a function inactive:

double(x) = 2*x
EnzymeRules.inactive(::typeof(double), args...) = nothing

autodiff(Forward, x -> x + double(x), Duplicated(2.0, 1.0)) # mathematically should be 3.0, inactive rule causes it to be 1.0
(1.0,)
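The effect of an incorrect inactive mark can be illustrated with a minimal hand-rolled dual number. This is not how Enzyme works internally (it differentiates compiled code rather than overloading operators); the Dual type and the two double variants are hypothetical names for this sketch only. Treating a call as inactive keeps its value but drops its tangent contribution, which mirrors how the inactive rule turns 3.0 into 1.0 above.

```julia
# Minimal dual number for forward-mode AD; purely illustrative, not Enzyme's implementation.
struct Dual
    val::Float64   # primal value
    der::Float64   # tangent
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(c::Real, a::Dual) = Dual(c * a.val, c * a.der)

double_active(x::Dual) = 2 * x
# Treating the call as inactive keeps its value but discards its tangent.
double_inactive(x::Dual) = Dual(2 * x.val, 0.0)

x = Dual(2.0, 1.0)                       # value 2.0, seed tangent 1.0
correct = (x + double_active(x)).der     # 3.0
wrong   = (x + double_inactive(x)).der   # 1.0
```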

Testing our rules

We can test our rules against finite differences with EnzymeTestUtils.test_forward and EnzymeTestUtils.test_reverse.

using EnzymeTestUtils, Test

@testset "f rules" begin
    @testset "forward" begin
        @testset for RT in (Const, DuplicatedNoNeed, Duplicated),
            Tx in (Const, Duplicated),
            Ty in (Const, Duplicated)

            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_forward(g, RT, (x, Tx), (y, Ty))
        end
    end
    @testset "reverse" begin
        @testset for RT in (Active,),
            Tx in (Duplicated,),
            Ty in (Duplicated,),
            fun in (g, h)

            x = [3.0, 1.0]
            y = [0.0, 0.0]
            test_reverse(fun, RT, (x, Tx), (y, Ty))
        end
    end
end
(The cell returns a large Test.DefaultTestSet("f rules", …) object with per-testset timing data, abridged here. It records no failures or errors: all 12 forward combinations pass, with 7 checks each when RT is Const or Duplicated and 6 when RT is DuplicatedNoNeed, and both reverse combinations, fun = g and fun = h, pass with 11 checks each.)

In any package that implements Enzyme rules using EnzymeRules, it is recommended to add EnzymeTestUtils as a test dependency to test the rules.
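The idea behind these tests, comparing derivatives against finite differences, can be sketched by hand. This is a simplified stand-in, not what EnzymeTestUtils does internally: the helper fd_grad_g and its step size are our own choices, and f_plain and g_plain are plain copies of f and g without any custom rules. A central-difference estimate of the gradient of g is compared with the analytic gradient 4 · sum(x.^2) · x.

```julia
# Plain-Julia copies of f and g from this tutorial (no custom rules attached).
function f_plain(y, x)
    y .= x .^ 2
    return sum(y)
end
g_plain(y, x) = f_plain(y, x)^2

# Hypothetical helper: central finite-difference estimate of the gradient of g w.r.t. x.
function fd_grad_g(x; hstep = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += hstep
        xm = copy(x); xm[i] -= hstep
        # g mutates its first argument, so each call gets a fresh buffer for y.
        grad[i] = (g_plain(zero(x), xp) - g_plain(zero(x), xm)) / (2 * hstep)
    end
    return grad
end

x = [3.0, 1.0]
analytic = 4 * sum(x .^ 2) .* x               # [120.0, 40.0]
fd = fd_grad_g(x)
agree = isapprox(fd, analytic; rtol = 1e-6)   # true
```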


This page was generated using Literate.jl.