Why Autodiff?

We propose to add automatic differentiation to Rust. This would allow Rust users to compute derivatives of arbitrary functions, which is the essential enabling technology for differentiable programming. This feature would open new opportunities for Rust in scientific computing, machine learning, robotics, computer vision, probabilistic analysis, and other fields.

Case studies from autodiff developers/users

Jan Hückelheim (Argonne National Lab, US):

Automatic differentiation (AD, also known as autodiff or back-propagation) has been used at Argonne and other national laboratories since at least the 1980s. For example, we have used AD to obtain gradients of computational fluid dynamics applications for shape optimization, which allows the automated design of aircraft wings or turbine blades to minimize drag or fuel consumption. AD is used extensively in many other applications, including seismic imaging, climate modeling, quantum computing, and software verification.

Besides the aforementioned “conventional” uses of AD, it is also a cornerstone for the development of ML methods that incorporate physical models. The 2022 Department of Energy report on Advanced Research Directions on AI for Science, Energy, and Security states that “End-to-end differentiability for composing simulation and inference in a virtuous loop is required to integrate first-principles calculations and advanced AI training and inference”. It is therefore conceivable that AD usage and development will become even more important in the near future.

Prof. Jed Brown (CU Boulder, US):

My primary applications are in computational mechanics (constitutive modeling and calibration), where it'll enable us to give a far better user experience than commercial packages, but differentiable programming is a key enabler for a lot of scientific computing and ML research and production.

Background

What is autodiff used for?

Autodiff is widely used to evaluate gradients for numerical optimization, which is otherwise intractable for a large number of parameters. Suppose we have a scalar-valued loss function \(f(x)\) where the parameter vector \(x\) has length \(n\). If the cost of evaluating \(f(x)\) once is \(c\) (which will often be \(O(n)\)), then evaluating the gradient \(\partial f/\partial x\) costs less than \(3c\) with autodiff (or with a tedious and brittle by-hand implementation), but about \(cn\) with finite differencing. Optimization of systems with \(n\) in the hundreds to billions is common in applications such as calibration, data assimilation, and design optimization of physical models, in perception and control systems for robotics, and in machine learning.

Derivatives are also instrumental to thermodynamically admissible physical models, in which models are developed using non-observable free energy functionals and dissipation potentials, with the observable dynamics represented by their derivatives. Commercial engineering software requires users to implement these derivatives by hand (e.g., Abaqus UHYPER and UMAT), and constitutive modeling papers routinely spend many pages detailing how to compute the necessary derivatives efficiently, since these are among the most computationally intensive parts of simulation-based workflows and numerical stability is essential.

Prior art

Autodiff emerged as an identified mathematical framework and software tool in the 1980s, building on groundwork from previous decades. The two common implementation strategies were operator overloading and source transformation. In the former approach, one replaces scalars with a pair carrying the primal value and its variation (derivative) and overloads the elementary operations, leading to code similar to the following:

#![allow(unused)]
fn main() {
use std::ops::{Add, Mul};

/// A scalar carrying its value and its variation (derivative).
struct ActiveReal(f64, f64);

impl Add for ActiveReal {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        Self(self.0 + rhs.0, self.1 + rhs.1)
    }
}

impl Mul for ActiveReal {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        // Apply the product rule, d(xy) = x * dy + dx * y
        Self(self.0 * rhs.0, self.0 * rhs.1 + self.1 * rhs.0)
    }
}
}
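
As a quick usage sketch (assuming the ActiveReal definitions above are in scope), differentiating \(f(x) = x \cdot x + 3\) at \(x = 3\) by seeding the variation of x with 1.0:

let x  = ActiveReal(3.0, 1.0);
let x2 = ActiveReal(3.0, 1.0); // the sketch type is not Copy, so construct it twice
let y = x * x2 + ActiveReal(3.0, 0.0); // a constant carries a zero variation
assert_eq!(y.0, 12.0); // primal value
assert_eq!(y.1, 6.0);  // derivative d(x*x + 3)/dx = 2x = 6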

This approach can be distributed as a library (in languages that support operator overloading), but in practice it means making most user functions generic, and it has serious performance impacts. For example, loops are effectively unrolled, and this strategy is harmful to vectorization since the "even and odd lanes" require different processing. Some libraries, especially in C++, have applied heavy template metaprogramming to make this reasonably performant, but they still lag behind and produce obtuse error messages.

The second approach, source transformation, was first applied to Fortran and later to C, but it never differentiated the complete language. Those systems implemented their own parsers and generated new source files with mangled symbols and calling conventions, to be manually included by the build system. Such approaches could be much more efficient by retaining loop structure and enabling vectorization, but language coverage was always incomplete and the tools have historically been difficult to install (e.g., OpenAD depended on the research compiler ROSE, which could only be compiled with a specific version of gcc). Error handling and debugging are also poor with this approach.

Autodiff remained an active research area with relatively clumsy research-grade tooling for decades. Select scientific applications, such as MITgcm, would shape their entire software engineering around enabling source transformation tools to compute derivatives. The most recent machine learning boom was precipitated by autodiff engineering improving to the point where practitioners could rapidly design new neural network architectures and have efficient gradients automatically available for gradient-based optimization. The modern deep learning libraries (PyTorch, TensorFlow, JAX) have grown into more general autodiff tools, albeit with language restrictions and sharp edges due to custom JIT compilation separate from their host language (Python with C++ infrastructure).

Enzyme takes advantage of the fact that while differentiating all features of a source language is a monumental task, differentiating a static single-assignment (SSA) form such as LLVM IR is far more tractable and allows complete language coverage. Since it is based on LLVM, it is language-agnostic across LLVM-based languages.

TODO: https://enzyme.mit.edu/conference/assets/JanHueckelheim_EnzymeCon2023.pdf [paper]

TODO: Talk about Torch, TensorFlow, PyTorch, JAX, Julia, etc.

Installation

This work is in the process of being upstreamed into nightly Rust. In the future a manual installation should not be necessary anymore (unless you want to contribute to this project). Until then, you can follow the instructions below. Please be aware that the msvc target is not supported at the moment; all other tier 1 targets should work. Please open an issue if you run into problems on a supported tier 1 target, or if you successfully build this project on a tier 2/tier 3 target.

Build instructions

First you need to clone and configure this Rust fork:

git clone --depth=1 git@github.com:EnzymeAD/rust.git
cd rust
./configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs

Afterwards you can build rustc using:

./x.py build --stage 1 library

Afterwards, rustup toolchain link will allow you to use it through cargo:

rustup toolchain link enzyme build/host/stage1
rustup toolchain install nightly # enables -Z unstable-options

You can then run examples from our docs:

cd ..
git clone git@github.com:EnzymeAD/rustbook.git  
cd rustbook/samples
cargo +enzyme test reverse

If you want to use autodiff in your own projects, you will need to add lto = "fat" to the profile section of your Cargo.toml and use cargo +enzyme instead of cargo or cargo +nightly.

Usage

Enzyme differentiates arbitrary multivariate vector functions, the most general case in automatic differentiation:

\[ f: \mathbb{R}^n \rightarrow \mathbb{R}^m, y = f(x) \]

For simplicity we define a vector function with \(m=1\). However, this tutorial can easily be applied to arbitrary \(m \in \mathbb{N}\).

#![allow(unused)]
fn main() {
fn f(x: &[f32], y: &mut f32) {
    *y = x[0] * x[0] + x[1] * x[0];
}
}

We also support functions that return a float value:

#![allow(unused)]
fn main() {
fn g(x: &[f32]) -> f32 {
    x[0] * x[0] + x[1] * x[0]
}
}

Forward Mode

The forward model is defined as \[ \begin{aligned} y &= f(x) \\ \dot{y} &= \nabla f(x) \cdot \dot{x} \end{aligned} \]

To obtain the first element of the gradient using the forward model we have to seed \(\dot{x}\) with \(\dot{x}=[1.0,0.0]\).
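
Plugging in the concrete values used in the example below, \(x = [2.0, 3.0]\) and \(\dot{x} = [1.0, 0.0]\), gives \[ \nabla f(x) = \big[\, 2x_0 + x_1,\; x_0 \,\big] = [7.0,\, 2.0], \qquad \dot{y} = \nabla f(x) \cdot \dot{x} = 7.0, \] which matches the value asserted for dy below.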

In forward mode, the second (shadow) argument added for each Dual argument stores the tangent.

    #[autodiff(df, Forward, Dual, Dual)]
    fn f(x: &[f32; 2], y: &mut f32) {
        *y = x[0] * x[0] + x[1] * x[0];
    }

    fn main() {
        let x = [2.0, 3.0];
        let dx = [1.0, 0.0];
        let mut y = 0.0;
        let mut dy = 0.0;
        df(&x, &dx, &mut y, &mut dy);
        assert_eq!(10.0, y);
        assert_eq!(7.0, dy);
    }

In the returning case we write similar code; note that the second Dual now refers to the return value.

    #[autodiff(df, Forward, Dual, Dual)]
    fn f(x: &[f32; 2]) -> f32 { x[0] * x[0] + x[1] * x[0] }

    fn main() {
        let x  = [2.0, 2.0];
        let dx = [1.0, 0.0];
        let (y, dy) = df(&x, &dx);
        assert_eq!(dy, 2.0 * x[0] + x[1]);
        assert_eq!(y, f(&x));
    }

Note that to acquire the full gradient one needs to execute the forward model a second time with the seed dx set to [0.0, 1.0].

Reverse Mode

The reverse model in AD is defined as \[ \begin{aligned} y &= f(x) \\ \bar{x} &= \bar{y} \cdot \nabla f(x) \end{aligned} \] where the bar denotes an adjoint variable. Note that executing AD in reverse mode takes \(x, \bar y\) as input and computes \(y, \bar x\) as output.
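
For the example function from the previous chapter, seeding \(\bar{y} = 1.0\) at \(x = [2.0, 3.0]\) gives \[ \bar{x} = \bar{y} \cdot \nabla f(x) = 1.0 \cdot \big[\, 2x_0 + x_1,\; x_0 \,\big] = [7.0,\, 2.0], \] i.e. the full gradient from a single reverse-mode call.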

Enzyme stores the value and adjoint of a variable when a type is marked as Duplicated. The first element then represents the value and the second the adjoint. The following example evaluates the reverse model using Enzyme.

    #[autodiff(df, Reverse, Duplicated, Duplicated)]
    fn f(x: &[f32; 2], y: &mut f32) {
        *y = x[0] * x[0] + x[1] * x[0];
    }

    fn main() {
        let x = [2.0, 3.0];
        let mut bx = [0.0, 0.0];
        let mut y = 0.0;
        let mut by = 1.0; // seed
        df(&x, &mut bx, &mut y, &mut by);
        assert_eq!([7.0, 2.0], bx);
        assert_eq!(10.0, y);
        assert_eq!(0.0, by); // seed is zeroed
    }

This yields the gradient of f in bx at the point x = [2.0, 3.0]. by is called the seed and has to be set to 1.0 in order to compute the gradient. Please note that unlike Dual, for Duplicated the seed gets zeroed, which is required for correctness in certain cases. Note that the output bx is initialized to zero on input, and df accumulates into it.
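
To illustrate the accumulation behaviour, here is a small sketch reusing the df generated above; the concrete values assume the accumulation and seed-zeroing semantics described in this chapter:

    fn main() {
        let x = [2.0, 3.0];
        let mut bx = [0.0, 0.0];
        let mut y = 0.0;

        let mut by = 1.0;
        df(&x, &mut bx, &mut y, &mut by);
        by = 1.0; // the seed was zeroed by the first call, so reset it
        df(&x, &mut bx, &mut y, &mut by);

        // bx now holds twice the gradient, because df accumulated into it.
        assert_eq!([14.0, 4.0], bx);
    }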

We can also handle functions returning a scalar. In this case we mark the return value as Active. The seed is then passed as an extra, final input argument.

    #[autodiff(df, Reverse, Duplicated, Active)]
    fn f(x: &[f32; 2]) -> f32 {
        x[0] * x[0] + x[1] * x[0]
    }

    fn main() {
        let x = [2.0, 3.0];
        let mut bx = [0.0, 0.0];
        let by = 1.0; // seed
        let y = df(&x, &mut bx, by);
        assert_eq!([7.0, 2.0], bx);
        assert_eq!(10.0, y);
    }

We can now verify that indeed the reverse mode and forward mode yield the same result.

    #[autodiff(df_fwd, Forward, Dual, Dual)]
    #[autodiff(df_rev, Reverse, Duplicated, Duplicated)]
    fn f(x: &[f32; 2], y: &mut f32) {
        *y = x[0] * x[0] + x[1] * x[0];
    }

    fn main() {
        let x = [2.0, 3.0];

        // Compute gradient via forward-mode
        let dx_0 = [1.0, 0.0];
        let dx_1 = [0.0, 1.0];
        let mut y = 0.0;
        let mut dy_f = [0.0, 0.0];
        df_fwd(&x, &dx_0, &mut y, &mut dy_f[0]);
        df_fwd(&x, &dx_1, &mut y, &mut dy_f[1]);
        assert_eq!([7.0, 2.0], dy_f);

        // Compute gradient via reverse-mode
        let mut bx = [0.0, 0.0];
        let mut y = 0.0;
        let mut by = 1.0; // seed
        df_rev(&x, &mut bx, &mut y, &mut by);
        assert_eq!([7.0, 2.0], bx);
        assert_eq!(10.0, y);
        assert_eq!(0.0, by); // seed is zeroed
    }

As we can see, the number of calls in forward mode scales with the number of input values, whereas reverse mode scales with the number of outputs and is therefore preferable if we have fewer outputs than inputs. A common example is the training of neural networks, where we have a single output (the loss) but a large number of inputs (the weights).

Forward Mode

Forward mode (often also called dual mode) is generally recommended if the output dimension is greater than the active input dimension, or if the dimensions are similar. Dimension here refers to the total number of scalar values in all input (respectively output) arguments.

We annotate input and output arguments either with Const, Dual, or DualOnly.

In forward mode we are only allowed to mark input arguments with Dual, DualOnly, or Const. The return value of a forward-mode function with a Dual return is a tuple containing the primal return value as the first element and the derivative as the second.

In forward mode, Dual with the two arguments x, 0.0 is mathematically equivalent to Const with only x passed, except that Const allows Enzyme to perform more optimizations.
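
For example, marking an argument as Const means that no shadow argument is added for it. The following sketch assumes that a by-value f32 may be marked as Dual and that each shadow directly follows its primal argument, as in the earlier examples:

    #[autodiff(df, Forward, Dual, Const, Dual)]
    fn f(x: f32, y: f32) -> f32 {
        x * x + y
    }

    fn main() {
        // Arguments: x, dx (shadow of x), y (no shadow, since it is Const).
        let (v, dv) = df(3.0, 1.0, 5.0);
        assert_eq!(v, 14.0);
        assert_eq!(dv, 6.0); // d/dx (x*x + y) = 2x; y is treated as a constant
    }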

Reverse Mode

Autodiff on LLVM IR

TODO: Typical LICM \(O(n)\) vs \(O(n^2)\) Enzyme example. TODO: Talk about what makes this approach special and a good fit for Rust conceptually.

Changes to Rust

TODO: Talk about the new attributes and define the semantics of these new attributes. Give examples.

Reverse Mode

Both the in-place and the "normal" variant provide the gradient. The difference is that with Active the gradient is returned by value, while with Duplicated it is accumulated in place into the shadow argument.

Usage story

Let us start by looking at the most basic examples we can think of:

\[ f(x,y) = x^2 + 3y \]

We have two input variables \(x\), \(y\) and a scalar return value. The gradient is

\[ \nabla f = \Big[\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y} \Big] = \big[2x, 3 \big] \]

Let's check for Enzyme (our compiler explorer does not handle Rust yet, so you'll have to trust me on this).

    #[autodiff(df, Reverse, Active, Active, Active)]
    fn f(x: f32, y: f32) -> f32 {
        x * x + 3.0 * y
    }

    fn main() {
        let (x, y) = (5.0, 7.0);
        let (bx, by, z) = df(x, y, 1.0);
        assert_eq!(46.0, z);
        assert_eq!(10.0, bx);
        assert_eq!(3.0, by);
    }

Enzyme actually generates the code at the LLVM-IR level, but Rust is nicer to read, so I will pretend we generate a Rust implementation:

fn f(x: f32, y: f32) -> f32 {
  x * x + 3.0 * y
}
fn df(x: f32, y: f32) -> (f32, f32, f32) {
  let d_dx = 2.0 * x;
  let d_dy = 3.0;
  let f = x * x + 3.0 * y;
  (d_dx, d_dy, f)
}

Note that the last entry in the result tuple contains the original return value. However, we don't always pass things by value, so let's make sure we have a sensible solution:

#[autodiff(df, Reverse, Active, Duplicated, Active)]
fn f(x: f32, y: &f32) -> f32 {
  x * x + 3.0 * y
}

(pay attention to y).

fn f(x: f32, y: &f32) -> f32 {
  x * x + 3.0 * y
}
fn df(x: f32, y: &f32, d_dy: &mut f32) -> (f32, f32) {
  let d_dx = 2.0 * x;
  *d_dy += 3.0;
  let f = x * x + 3.0 * y;
  (d_dx, f)
}

In the case of references (or pointers) we do expect the user to create d_dy.

We could obviously zero-initialize a float for the user, but let's assume the constructor is complex due to involving a doubly-linked list or FFI, so we can't guarantee that on the compiler side. As an alternative motivation, imagine that we call df 5 times in a loop. It is clear that in this case the accumulated gradients should also be 5 times higher, which won't happen if we set d_dy = 3.0 on each call instead of using +=. Yet another reason would be higher-order derivatives (todo: just refer to literature?).

Now that we got back from this rabbit hole, let's go wild and train a neural network on our local national lab server:

#[autodiff(backprop, Reverse, Duplicated, Duplicated, Active)]
fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}

Now Enzyme gives us:

fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}
fn backprop(images: &[f32], dimages: &mut [f32], weights: &[f32], dweights: &mut [f32]) -> f32 {
  enzyme_update_inplace_dx(dimages);
  enzyme_update_inplace_dy(dweights);
  let loss = do_some_math(images, weights);
  loss
}

Uuuuhm. Yeah? We want to minimize our loss, so let's do weights -= learning_rate * dweights;
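
A hypothetical training loop wiring this together might look like the following sketch (train, learning_rate, and steps are made-up names; it assumes the backprop signature shown above):

fn train(images: &[f32], weights: &mut [f32], learning_rate: f32, steps: usize) {
    let mut dimages = vec![0.0; images.len()];
    let mut dweights = vec![0.0; weights.len()];
    for _ in 0..steps {
        // The shadows are accumulated in-place, so clear them every iteration.
        dimages.iter_mut().for_each(|g| *g = 0.0);
        dweights.iter_mut().for_each(|g| *g = 0.0);
        let loss = backprop(images, &mut dimages, weights, &mut dweights);
        for (w, g) in weights.iter_mut().zip(dweights.iter()) {
            *w -= learning_rate * g;
        }
        println!("loss: {loss}");
    }
}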

We also just learned how we can update our images through dimages, but unless you know how to shape the world around you that's pretty useless, so we just wasted a good amount of compute time on gradients we don't need. Let's try again:

#[autodiff(backprop, Reverse, Const, Duplicated, Active)]
fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}

After all, we shouldn't modify our train and test images to improve our accuracy, right? So we now generate:

fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}
fn backprop(images: &[f32], weights: &[f32], dweights: &mut [f32]) -> f32 {
  enzyme_update_inplace_dy(dweights);
  let loss = do_some_math(images, weights);
  loss
}

Great. No more random dimages that we don't know how to handle. Perfection? Almost:

#[autodiff(backprop, Reverse, Const, Duplicated, DuplicatedNoNeed)]
fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}

Happy to accept better names than DuplicatedNoNeed. Either way, now we have:

fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}
fn backprop(images: &[f32], weights: &[f32], dweights: &mut [f32]) {
  enzyme_update_inplace_dy(dweights);
}

We run backprop to get the gradients to update our weights; tracking the loss while training is optional. Keep in mind that this allows Enzyme to do some slightly advanced dead code elimination, but at the end of the day Enzyme will still need to compute most of do_some_math(images, weights) in order to calculate dweights. How much runtime you save by not asking for the loss will therefore depend on your application. We won't introduce a new motivation for our last example, but let's assume we have reasons to only want dweights and do not care about the original weights anymore.

#[autodiff(backprop, Reverse, Const, DuplicatedNoNeed, DuplicatedNoNeed)]
fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}

DuplicatedNoNeed allows Enzyme to reuse the memory of our weights variable as scratch space. That might increase performance, but in exchange the variable must not be assumed to hold meaningful values afterwards. Leaving a variable in such an invalid state is acceptable in Julia, C++, etc., but not in Rust. We had some discussion on whether this can be represented as MaybeUninit or Option, but haven't reached a conclusion yet. (WIP)

fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}
fn backprop(images: &[f32], weights: &[f32], dweights: &mut [f32]) {
  enzyme_update_inplace_dy(dweights);
}

And as the very last one, Enzyme follows JAX and the other AD tools by allowing batched backpropagation:

#[autodiff(backprop, Reverse(2), Const, Duplicated, DuplicatedNoNeed)]
fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}

We don't expose batch mode on the Rust side yet; let's take one step at a time. Conceptually, the generated code would look like this:

fn training_loss(images: &[f32], weights: &[f32]) -> f32 {
  let loss = do_some_math(images, weights);
  loss
}
fn backprop(images: (&[f32], &[f32]), weights: (&[f32], &[f32]), dweights: (&mut [f32], &mut [f32])) {
  enzyme_update_inplace_dy(dweights.0);
  enzyme_update_inplace_dy(dweights.1);
}

Higher Order Derivatives

Computing higher-order derivatives such as Hessians can be done with Enzyme by differentiating functions that compute lower-order derivatives. This requires that the functions are differentiated in the right order, which we currently don't handle automatically. As a workaround, we introduce two new AD modes, ForwardFirst and ReverseFirst, which will be differentiated (and optimized) before we differentiate the default Forward and Reverse mode invocations. An example is given below.

// A direct translation of
// https://enzyme.mit.edu/index.fcgi/julia/stable/generated/autodiff/#Forward-over-reverse

#[autodiff(ddf, Forward, Dual, Dual, Dual, Dual)]
fn df2(x: &[f32;2], dx: &mut [f32;2], out: &mut [f32;1], dout: &mut [f32;1]) {
    df(x, dx, out, dout);
}

#[autodiff(df, ReverseFirst, Duplicated, Duplicated)]
fn f(x: &[f32;2], y: &mut [f32;1]) {
    y[0] = x[0] * x[0] + x[1] * x[0]
}

#[test]
fn main() {
    let mut y = [0.0];
    let x = [2.0, 2.0];

    let mut dy = [0.0];
    let mut dx = [1.0, 0.0];

    let mut bx = [0.0, 0.0];
    let mut by = [1.0];
    let mut dbx = [0.0, 0.0];
    let mut dby = [0.0];

    ddf(&x, &mut bx, &mut dx, &mut dbx, 
        &mut y, &mut by, &mut dy, &mut dby);
}

Current limitations

Safety and Soundness

Enzyme currently assumes that the user passes shadow arguments (dx, dy, ...) of appropriate size. In reverse mode, we additionally assume that shadow arguments are mutable. In both modes we automatically insert checks to verify that Dual/Duplicated slices have shadow arguments of the right size. In reverse mode we also adjust the outermost pointer or reference to be mutable; therefore &f32 will receive the shadow type &mut f32. However, we do not check the length of types other than slices (e.g. enums, Vec). We also do not enforce mutability of inner references, but will warn if we recognize them. We intend to add additional checks over time.
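
As a sketch of the scenario those length checks guard against (the exact failure behaviour is an implementation detail):

    #[autodiff(df, Reverse, Duplicated, Duplicated)]
    fn f(x: &[f32], y: &mut f32) {
        *y = x[0] * x[0] + x[1] * x[0];
    }

    fn main() {
        let x = [2.0, 3.0];
        let mut bx = [0.0]; // shadow is shorter than the primal slice
        let mut y = 0.0;
        let mut by = 1.0;
        // The inserted check rejects this call instead of risking
        // out-of-bounds writes into bx.
        df(&x, &mut bx, &mut y, &mut by);
    }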

ABI adjustments

In some cases, a function parameter might get lowered in a way that we currently don't handle correctly, leading to a compile-time type mismatch in the rustc_codegen_llvm backend. Here are some examples.

Compile Times

Enzyme will often achieve excellent runtime performance, but might increase your compile time by a large factor. For Rust, we have already made significant improvements and have a list of further improvements planned - please reach out if you have time to help here.

Type Analysis

Most of the time, Type Analysis (TA) is the reason for large (>5x) compile time increases when using Enzyme. The bottom-left part of this poster explains why we need to run Type Analysis: Poster Link. We intend to increase the number of locations where we pass down type information based on Rust types, which in turn will reduce the number of locations where Enzyme has to run Type Analysis, which will help compile times.

Duplicated Optimizations

The key reason Enzyme often offers excellent performance is that it differentiates LLVM-IR that has already been optimized. However, we also (have to) run LLVM's optimization pipeline after differentiating, to make sure that the code Enzyme generates is optimized properly. As a result you should get excellent runtime performance (please file an issue if not), but at the compile time cost of running optimizations twice.

FAT-LTO

The usage of #[autodiff(...)] currently requires compiling your project with fat LTO. We technically only need LTO if the function being differentiated calls functions in other compilation units. Other solutions are therefore possible, but this is the simplest one to get started. The compile time overhead of LTO is small compared to the current overhead of differentiating larger functions, so this limitation is not a priority at the moment.

Future Work

Parallelism:

Enzyme supports efficiently differentiating parallel code. Its ability to combine differentiation with optimization (including parallel optimizations) enables orders-of-magnitude improvements in performance and scaling for parallel code. Each parallel framework only needs to provide Enzyme with lightweight markers describing where the parallelism is created (e.g. a parallel for or spawn/sync). Such markers have been added for various parallel paradigms, including CUDA, ROCm, OpenMP, MPI, Julia tasks, and RAJA.

Such markers have not yet been added for Rust parallel libraries (e.g. rayon). Enzyme only needs to support the lowest level of parallelism for each language, so adding support for rayon should cover most cases. We estimate that 20-200 lines of code in Enzyme core should be sufficient, making it a nice task to get started.
rsmpi (the Rust wrapper for MPI) should already work, but it would be good to test it.

Batching

Batching allows computing multiple derivatives at once, which can help amortize the cost of the forward pass. It can also be used to enable future vectorization. This feature is quite popular for ML projects; the JAX documentation gives an example here. Batching is supported by Enzyme core and could be implemented for Rust-Enzyme in a few hours; the main blocker is bikeshedding around the frontend. Do we want to accept N individual shadow arguments? A tuple of N elements? An array [T; N]?

Custom Derivatives

Let's assume that you want to use differentiable rendering, but someone added a "fast" inverse square root to your render engine, breaking your Rust-Enzyme tool, which can't figure out how i = 0x5f3759df - ( i >> 1 ); affects your gradient. For this reason, autodiff packages allow declaring a custom derivative f' for a function f. In such a case the AD tool will not look at the implementation of f and will directly use the user-provided f'. The JAX documentation also has a large list of other reasons why you might want to use custom derivatives: link. Julia has a whole ecosystem, ChainRules.jl, around custom derivatives. Enzyme does support custom derivatives, but we do not expose this feature on the Rust side yet. Together with batching, this is one of the highest-reward / lowest-effort improvements planned for Rust-Enzyme.
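
To make the problem concrete, here is a sketch of the kind of function involved and the analytical derivative one would want to register for it (the names are made up, and the registration mechanism itself is not exposed on the Rust side yet):

fn fast_inv_sqrt(x: f32) -> f32 {
    // Bit tricks like this carry no useful derivative information for an AD tool.
    let i = 0x5f3759df - (x.to_bits() >> 1);
    let y = f32::from_bits(i);
    y * (1.5 - 0.5 * x * y * y) // one Newton iteration to refine the estimate
}

// The derivative one would register by hand: d/dx x^(-1/2) = -0.5 * x^(-3/2).
fn d_fast_inv_sqrt(x: f32, dx: f32) -> f32 {
    -0.5 * x.powf(-1.5) * dx
}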

Custom Allocators:

Enzyme does support custom allocators, but Rust-Enzyme does not expose support for them yet. Please let us know if you have an application that would benefit from a custom allocator and autodiff; otherwise this likely won't be implemented in the foreseeable future.

Checkpointing:

Enzyme is very fast partly because it runs optimizations before AD, and it includes various partial checkpointing algorithms, such as a min-cut algorithm. However, the ability to control checkpointing (e.g. whether to recompute or store intermediate values) has not yet been exposed to Rust. Finding the optimal checkpointing balance for a given program is in general an NP-hard problem, but good approximations exist. You can think of it in terms of custom allocators: replacing the algorithm might affect your runtime performance, but does not affect the result of your function calls. In the future it might be interesting to let the user interact with checkpointing.

Supporting other Codegen backends:

Enzyme consists of ~50k LoC. Most of the rules for generating derivatives of instructions are written as LLVM TableGen (.td) declarations, and as such it should be relatively easy to port them. Enzyme also includes various experimental features which we don't need on the Rust side; an implementation for another codegen backend could therefore end up a bit smaller. The cranelift backend would also benefit from ABI compatibility, which makes it very easy to test the correctness of a new autodiff tool against Enzyme. Our modifications to rustc_codegen_ssa and earlier layers of rustc are written in a generic way, s.t. no changes would be needed there to enable support for additional backends.

GPU / TPU / IPU / ... support.

Enzyme supports differentiating CUDA/ROCm kernels. There are various ways of exposing these capabilities to Rust. Manuel and Jed will be experimenting with two different approaches in 2024, and there is also a lot of simultaneous research. Please reach out if you are also working on GPU programming in Rust.

MLIR support:

Enzyme partly supports multiple MLIR dialects. MLIR can offer great runtime performance benefits for certain workloads. It would be nice to have a rustc_codegen_mlir, but there is a very large number of open questions around the design.

History and ecosystem

Enzyme started as a project created by William Moses and Valentin Churavy to differentiate LLVM-IR, covering languages with an LLVM frontend like C, Julia, Swift, Fortran, etc. Operating within the compiler enables Enzyme to interoperate with optimizations, allowing for higher performance than conventional approaches while not needing special handling for each language and construct. Enzyme is an LLVM incubator project and intends to ask for upstreaming later in 2024.

In 2020, initial investigations into using Enzyme on Rust were led by Tiberius Ferreria and William Moses through the use of foreign function calls (https://internals.rust-lang.org/t/automatic-differentiation-differential-programming-via-llvm/13188/7).

In 2021, Manuel Drehwald and Lorenz Schmidt worked on Oxide-Enzyme, which aimed to integrate Enzyme directly as a compiler-aware cargo plugin.

The current Rust-Enzyme project embeds Enzyme directly into rustc and exposes autodiff macros for easy usage. The project is led by Manuel Drehwald, in collaboration with Jed Brown, William Moses, Lorenz Schmidt, Ningning Xie, and Rodrigo Vargas-Hernandez.

Development of a Rust-Enzyme frontend

We hope that as part of the nightly releases Rust-Enzyme can mature relatively fast because:

  1. Unlike Julia, Rust does not emit code involving garbage collection, JIT compilation, or type-unstable code -- simplifying the inputs to Enzyme (and reducing the need to develop support for such mechanisms, which have since been added to Enzyme.jl).
  2. Unlike Clang, we do ship the source code of the standard library. On the Rust side, we therefore don't need to manually add support for standard-library functions, as is necessary for libstdc++ functions like the std::map iterator decrement.
  3. Minimizing Rust code is reasonably pleasant, and Cargo/crates.io make it easy to reproduce bugs.

Non-alternatives

The key aspect for the performance of our solution is that AD is performed after compiler optimizations have been applied (and is able to run additional optimizations afterwards). This observation is mostly language independent; it is motivated in the 2020 Enzyme NeurIPS paper and also mentioned towards the end of this non-Enzyme Java autodiff case study.

Wrapping cargo instead of modifying rustc

We can use Enzyme without modifying rustc, as demonstrated in oxide-enzyme.

  1. We let users specify a list of functions which they want to differentiate and how (forward/reverse, activities...). example.
  2. We manually emit the optimized llvm-ir of our Rust program and all dependencies.
  3. We llvm-link all files into a single module (equivalent to fat-lto).
  4. We call Enzyme to differentiate functions.
  5. We adjust linker visibility of the new functions and create an archive that exports those new functions.
  6. We terminate this cargo invocation (which can e.g. be achieved by -Zno-link).
  7. We call cargo a second time, this time providing our archive as additional linker argument. The functions provided by the archive exactly match the extern fn declarations created through our macro here.

This PoC required the use of build-std in order to see the llvm-ir of functions from the std lib.
An alternative would have been to provide Enzyme with rules on how to differentiate every function from the Rust std, which seems undesirable. It would not be impossible, however; C++-Enzyme has various rules for the C++ std lib.

This approach also assumes that linking llvm-ir generated by two different cargo invocations and passing Rust objects between those works fine.

This approach is further limited in compile times and reliability. See the example at the bottom left of this poster. LLVM types are often too limited to determine the correct derivative (e.g. opaque ptr), and as such Enzyme has to run a usage analysis to determine the relevant type of a variable. This can be time consuming (we encountered multiple cases with >1000x longer compile times) and it can be unreliable if Enzyme fails to deduce the correct type of a variable due to insufficient usages. When calling Enzyme from within rustc, we are able to provide high-level type information to Enzyme. For oxide-enzyme, we tried to mitigate this by using a DWARF debug-info parser (requiring debug information even in release builds), but even with these helpers we were completely unable to support enums, due to their ability to represent different types. This approach was also limited because rustc (at the time we wrote it) did not emit DWARF information for all Rust types with unstable layout.

Rust level autodiff

Various Rust libraries for training neural networks exist (burn/candle/dfdx/rai/autograph). We talked with developers from burn, rai, and autograph to compare the autodiff performance under the Microsoft ADBench benchmark suite. After some investigation, all three decided that supporting such cases would require significant redesigns of their projects, which they can't afford in the foreseeable future.
When training neural networks, we often look at a few large variables (tensors) and a small set of functions (layers) which dominate the runtime. Using these properties it is possible to amortize some inefficiencies by making the most expensive operations efficient. Such optimizations stop working once we look at the larger set of applications in scientific computing or HPC.

How to Debug AD?

Since Rust-AD is still in early development, crashes are not unlikely.

Frontend crashes

If you see a proper Rust stacktrace after a compilation failure, our frontend (thus rustc) has likely crashed. It should often be trivial to create a minimal reproducer by deleting most of the body of the function being differentiated, or by replacing the function body with a loop {} statement. Please create an issue with such a reproducer; it will likely be easy to fix!
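
For example, such a stripped-down reproducer might look like the following sketch (keep the signature and autodiff attribute from your code, and drop the body):

#[autodiff(df, Reverse, Duplicated, Active)]
fn f(_x: &[f32]) -> f32 {
    loop {}
}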

In the unexpected case that you produce an ICE in our frontend that is harder to minimize, please consider using icemelter.

Backend crashes

If you see LLVM-IR (a language which might remind you of assembly), then our backend crashed. You can find instructions on how to create an issue and help us fix it on the next page.

Debugging and Profiling

Rust-AD supports passing an autodiff flag to RUSTFLAGS, which supports changing the behaviour of Enzyme in various ways. Documentation is available here.

Reporting backend crashes

If after a compilation failure you are greeted by a large amount of LLVM-IR code, then our Enzyme backend likely failed to compile your code. These cases are harder to debug, so your help is highly appreciated. Please also keep in mind that release builds are usually much more likely to work at the moment.

The final goal here is to reproduce your bug in the Enzyme compiler explorer, in order to create a bug report in the Enzyme core repository.

We have an autodiff flag which you can pass to RUSTFLAGS to help with this. It will print the whole LLVM-IR module, along with dummy functions called enzyme_opt_dbg_helper_<i>. A potential workflow on Linux could look like:

1) Generate an LLVM-IR reproducer

RUSTFLAGS="-Z autodiff=OPT" cargo +enzyme build --release &> out.ll 

This also captures a few warnings and info messages above and below your module. Open out.ll and remove every line above ; ModuleID = <SomeHash>. Now look at the end of the file and remove everything that's not part of the LLVM-IR, i.e. remove errors and warnings. The last line of your LLVM-IR should now start with !<someNumber> = , e.g. !40831 = !{i32 0, i32 1037508, i32 1037538, i32 1037559} or !43760 = !DILocation(line: 297, column: 5, scope: !43746). The actual numbers will depend on your code.

2) Check your LLVM-IR reproducer

To confirm that your previous step worked, we will use LLVM's opt tool. Find the path to your opt binary, which will look similar to <some_dir>/rust/build/<x86/arm/...-target-triple>/build/bin/opt. Also find the LLVMEnzyme-19.<so/dll/dylib> path, similar to /rust/build/target-triple/enzyme/build/Enzyme/LLVMEnzyme-19. Once you have both, run the following command:

<path/to/opt> out.ll -load-pass-plugin=/path/to/LLVMEnzyme-19.so -passes="enzyme" -S

If the previous step succeeded, you are going to see the same error that you saw when compiling your Rust code with Cargo. If you fail to get the same error, please open an issue in the Rust repository. If you succeed, congrats! The file is still huge, so let's automatically minimize it.

3) Minimize your LLVM-IR reproducer

First find your llvm-extract binary; it is in the same folder as your opt binary. Then run:

<path/to/llvm-extract> -S --func=<name> --recursive --rfunc="enzyme_opt_helper_*" out.ll -o mwe.ll 

Please adjust the name passed with the --func flag. If you apply the #[no_mangle] attribute to the function you differentiate, you can use its Rust name directly. Otherwise you will need to look up the mangled function name: open out.ll and search for __enzyme_fwddiff or __enzyme_autodiff. The first string argument in that function call is the name of your function. Example:

define double @enzyme_opt_helper_0(ptr %0, i64 %1, double %2) {
  %4 = call double (...) @__enzyme_fwddiff(ptr @_ZN2ad3_f217h3b3b1800bd39fde3E, metadata !"enzyme_const", ptr %0, metadata !"enzyme_const", i64 %1, metadata !"enzyme_dup", double %2, double %2)
  ret double %4
}

Here, _ZN2ad3_f217h3b3b1800bd39fde3E is the correct name. Make sure not to copy the leading @. Redo step 2), but now pass mwe.ll instead of out.ll to opt, to see if your minimized example still reproduces your crash.

4) (Optional) Minimize your LLVM-IR reproducer further.

After the previous step you should have an mwe.ll file with ~5k LoC. Let's try to get it down to 50. Find your llvm-reduce binary next to opt and llvm-extract. Copy the first line of your error message; an example could be:

opt: /home/manuel/prog/rust/src/llvm-project/llvm/lib/IR/Instructions.cpp:686: void llvm::CallInst::init(llvm::FunctionType*, llvm::Value*, llvm::ArrayRef<llvm::Value*>, llvm::ArrayRef<llvm::OperandBundleDefT<llvm::Value*> >, const llvm::Twine&): Assertion `(Args.size() == FTy->getNumParams() || (FTy->isVarArg() && Args.size() > FTy->getNumParams())) && "Calling a function with bad signature!"' failed.

If you just get a segfault, there is no sensible error message and not much to do automatically, so continue to 5).
Otherwise, create a script.sh file containing:

#!/bin/bash
<path/to/your/opt> $1 -load-pass-plugin=/path/to/LLVMEnzyme-19.so -passes="enzyme" \
    |& grep "/some/path.cpp:686: void llvm::CallInst::init"

Experiment a bit with which error message you pass to grep. It should be long enough to make sure that the error is unique. However, for longer errors including ( or ), you will need to escape them correctly, which can become annoying. Then run:

<path/to/llvm-reduce> --test=script.sh mwe.ll 

If you see Input isn't interesting! Verify interesting-ness test, you got the error message in script.sh wrong; you need to make sure that grep matches your actual error. If all works out, you will see a lot of iterations, ending with a new reduced.ll file. Verify with opt that you still get the same error.

5) Report your bug.

Afterwards, you should be able to copy and paste your mwe.ll (and reduced.ll) example into our compiler explorer.
Select LLVM IR as the language and opt 20 as the compiler. Replace the field to the right of your compiler with -passes="enzyme", if it is not already set. Hopefully, you will once again see your now-familiar error. Please use the share button to copy links to them.

Please create an issue on https://github.com/EnzymeAD/Enzyme/issues and share mwe.ll and (if you have it) reduced.ll, as well as the links to the compiler explorer. Please feel free to also add your Rust code or a link to it. With that, hopefully someone from the Enzyme core repository will be able to fix your bug. Once that has happened, I will update the Enzyme submodule inside the Rust compiler, which should allow you to differentiate your Rust code. Thanks for helping us improve Rust-AD.

Minimize Rust code

Beyond having a minimal LLVM-IR reproducer, it is also helpful to have a minimal Rust reproducer without dependencies. This allows us to add it as a testcase to CI once we fix it, which avoids regressions for the future.

There are a few tools to help you minimize the Rust reproducer. This is probably the simplest automated approach: cargo-minimize

Otherwise we have various alternatives, including treereduce, halfempty, or picireny

Potentially also creduce

Supported RUSTFLAGS

To support you while debugging, we have added an experimental -Z autodiff flag to RUSTFLAGS, which allows changing the behaviour of Enzyme without recompiling rustc. We currently support the following values for autodiff:

Debug Flags

PrintTA // Print TypeAnalysis information
PrintAA // Print ActivityAnalysis information
PrintPerf // Print AD related Performance warnings
Print // Print all of the above
PrintModBefore // Print the whole LLVM-IR module before running opts
PrintModAfterOpts // Print the whole LLVM-IR module after running opts, before AD
PrintModAfterEnzyme // Print the whole LLVM-IR module after running opts and AD
LooseTypes // Risk incorrect derivatives instead of aborting when type info is missing
OPT // Most Important debug helper: Print a Module that can run with llvm-opt + enzyme

LooseTypes is often helpful to get rid of Enzyme errors stating Can not deduce type of <X> and to be able to run some code. But please keep in mind that this flag absolutely has the chance to cause incorrect gradients. Even worse, the gradients might be correct for certain input values, but not for others. So please create issues about such bugs and only use this flag temporarily while you wait for your bug to be fixed.

Benchmark flags

For performance experiments and benchmarking we also support

NoModOptAfter // We won't optimize the whole LLVM-IR Module after AD
EnableFncOpt // We will optimize each derivative function generated individually
NoVecUnroll // Disables vectorization and loop unrolling
NoSafetyChecks // Disables Enzyme specific safety checks
RuntimeActivity // Enables the runtime activity feature from Enzyme 
Inline // Instructs Enzyme to apply additional inlining beyond LLVM's default
AltPipeline // Don't optimize IR before AD, but optimize the whole module twice after AD

You can combine multiple autodiff values using a comma as separator:

RUSTFLAGS="-Z autodiff=LooseTypes,NoVecUnroll" cargo +enzyme build

The normal compilation pipeline of Rust-Enzyme is

  1. Run your selected compilation pipeline. If you selected a release build, we will disable vectorization and loop unrolling.
  2. Differentiate your functions.
  3. Run your selected compilation pipeline again on the whole module. This time we do not disable vectorization or loop unrolling.

The AltPipeline flag will not run opts before AD, but will run them twice after AD - the first time without vectorization or loop unrolling, the second time with both enabled.

The two flags above allow you to adjust this default behaviour.

Other Enzyme frontends

Enzyme currently has experimental frontends for C/C++, Julia, Fortran, Numba, some MLIR dialects, and Rust.

General LLVM/MLIR, as well as C/C++/CUDA documentation is available at https://enzyme.mit.edu
Julia documentation is available at https://enzyme.mit.edu/julia
Rust documentation is available at https://enzyme.mit.edu/rust
Enzyme-JAX (including HLO MLIR AD) is available at https://github.com/EnzymeAD/Enzyme-JAX. Enzyme has been demonstrated on various other languages including Swift and Fortran, but no frontend has been developed to improve ease of use and installation for these languages.

We have a compiler-explorer fork with support for autodiff in C/C++/CUDA, Julia, and MLIR here.

Developer documentation is available at https://enzyme.mit.edu/doxygen
Please reach out if you would like to see support for additional languages.

Forward Mode

When using forward mode, we only have three choices of activity values: Dual, DualOnly, and Const. Dual arguments get a second "shadow" variable. Usually we seed the shadow variable of one Dual input to one and all others to zero, and then read the shadow values of our output arguments. We can also seed more than one input shadow, in which case the shadow of the output variables will be a linear combination based on the seed values. If we use a &mut reference as both input and output argument and mark it as Dual, the corresponding shadow seed might get overwritten. Otherwise, the seed value will remain unchanged.
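
As a sketch of that linear-combination behaviour, reusing the example function from the Usage chapter and seeding both inputs at once (the resulting tangent is the directional derivative along the seed):

    #[autodiff(df, Forward, Dual, Dual)]
    fn f(x: &[f32; 2], y: &mut f32) {
        *y = x[0] * x[0] + x[1] * x[0];
    }

    fn main() {
        let x = [2.0, 3.0];
        let dx = [1.0, 1.0]; // seed both inputs
        let mut y = 0.0;
        let mut dy = 0.0;
        df(&x, &dx, &mut y, &mut dy);
        // dy = 1.0 * df/dx0 + 1.0 * df/dx1 = 7.0 + 2.0
        assert_eq!(9.0, dy);
    }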

| Activity             | Dual          | DualOnly           | Const     |
|----------------------|---------------|--------------------|-----------|
| Non-integer input T  | Accept T, T   | Accept byVal(T), T | Unchanged |
| Integer scalar input | N/A           | N/A                | Unchanged |
| f32 or f64 output T  | Return (T, T) | Return T           | Unchanged |
| Other output types   | N/A           | N/A                | Unchanged |

DualOnly is a potentially faster version of Dual.

When applied to a return type, it will cause the primal return to not be computed. So in the case of fn f(x: f32) -> f32 { x * x }, we would now only return 2.0 * x, instead of (x * x, 2.0 * x), which we would get with Dual.

In the case of an input variable, DualOnly will cause the primal value to be passed by value, even when it was passed by reference in the original function. So fn f(x: &f32, out: &mut f32) {..} would become fn df(x: f32, dx: &mut f32, out: f32, dout: &mut f32) {..}.
This makes x and out inaccessible to the user, so we can use them as buffers and potentially skip certain computations. This is mostly valuable for larger types or more complex functions.

Reverse Mode

Reverse mode reverses the normal control flow, so in this case we have to seed the output variables. This can be a bit counter-intuitive at first, so we provide an additional helper for float arguments.

When marking a float input as Active, we internally treat it as having a seed of one, so there is no extra input argument. We then return the resulting float, either directly or, if your primal function already has a return value, by updating the new function to return a tuple with one additional float per Active input. Those additional float values are in the same order as your Active input float values. If you instead want to seed a float output, you can still mark it as Active. We will then append one float argument to the list of input arguments, which will act as your seed.

More advanced types which are passed by pointer or reference can be marked as Duplicated. In this case a shadow argument will be added. Please note that unlike forward mode, the shadow here will always be mutable, therefore *mut or &mut. This is necessary since we have to overwrite (zero) your seed values for correctness. A motivation for this is given below.

| Activity           | Active                  | Duplicated        | Const     |
|--------------------|-------------------------|-------------------|-----------|
| f32 or f64 input   | Return additional float | N/A               | Unchanged |
| &T or &mut T input | N/A                     | Add &mut T shadow | Unchanged |
| Other input types  | N/A                     | N/A               | Unchanged |
| f32 or f64 output  | Append float to inputs  | N/A               | Unchanged |
| Other output types | N/A                     | N/A               | Unchanged |

Similar to Forward Mode, we offer optimized versions of our Activity Types.

| Activity           | ActiveOnly                                       | DuplicatedOnly   |
|--------------------|--------------------------------------------------|------------------|
| f32 or f64 input   | N/A                                              | N/A              |
| &T or &mut T input | N/A                                              | Accept T, &mut T |
| f32 or f64 output  | Omit primal return value; append float to inputs | N/A              |

ActiveOnly cannot yield any optimization for input values, since it is only applicable to floats passed by value. When used on a return type, it has the same effect as DualOnly in forward mode: in the case of fn f(x: f32) -> f32 { x * x }, we would only return 2.0 * x, instead of the (x * x, 2.0 * x) which we would get with Active.

DuplicatedOnly has the same effect on input arguments as DualOnly. The original input will now be taken by value, allowing Enzyme to use it as scratch space. If the original type was a mutable reference, we can additionally save the computation of any updates to that value.

Unlike Forward Mode, Reverse Mode will always overwrite (zero) your seed! If you want to call a function generated through Reverse Mode multiple times, you will have to reset your seed values!

Design:

There are naturally three locations where we could introduce our autodiff macro: on the definition side of f, on the definition side of df, or on the call-site of df.
We have only implemented support for the first two, because this is enough to implement the third one with a normal Rust macro. Let us look at an example:

#![feature(autodiff)]
#[autodiff(df, Reverse, Active, Active)]
fn fsquare(x: f32) -> f32 {
  x * x
}

fn main() {
  let x = 3.14;
  let res = fsquare(x);
  let (res_, dres) = df(x, 1.0);
  // let dres = autodiff!(f, x);
  assert_eq!(dres, 2.0 * x);
  assert_eq!(res, res_);
}

Some tools always compute all gradients, or rely on dead code elimination to remove code computing gradients the user does not need. In comparison, Enzyme only generates the code needed to compute the requested gradients in the first place (and uses and extends various LLVM passes to do so). It is therefore crucial that users (can) specify only the minimal set of variables as active. Library authors (e.g. faer, minmax, ...) cannot reliably predict which gradients users will need, and offering all combinations would cause a combinatorial explosion. We therefore allow users to differentiate functions specified in a dependency. While not technically required, we do enforce that such a function in the dependency is marked as public, to not violate Rust's visibility rules. We believe for now that this is the best compromise for usability, but would be happy to hear feedback. An example:

#![feature(autodiff)]
// ; dependency.rs
pub fn f(x: f32) -> f32 {
  x * x
}

// ; main.rs
#[autodiff(f, Reverse, Active, Active)]
fn df(x: f32, d_df: f32) -> (f32, f32);

fn main() {
  let x = 3.14;
  let res = f(x);
  let (_, dres) = df(x, 1.0);
  assert_eq!(dres, 2.0 * x);
}

The ; is our way of ensuring that users can't provide an implementation that would later be overwritten by the autodiff tool. We desugar it into a combination of unimplemented!(), inline asm (a no-op), and bench-blackbox, to avoid the method being inlined. We take some additional care of this within the compiler.
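
Roughly, one can picture the placeholder being desugared into something like the following sketch (illustrative only; the real lowering happens inside the compiler and the body is later replaced by the generated derivative):

fn df(x: f32, d_df: f32) -> (f32, f32) {
    // Keep the arguments and the body opaque to the optimizer so the
    // placeholder is neither inlined nor removed before it is replaced.
    let _ = std::hint::black_box((x, d_df));
    unsafe { std::arch::asm!("") }; // inline-asm no-op
    unimplemented!()
}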

Let's look at the configuration now:

  1. The first argument of our macro is the name (or path) of the "other" function. That is, if we apply our macro to f, we put the name of df, which the macro will then generate. If we put our macro on top of an empty function declaration, the name should be the one of the function which we want to differentiate.
  2. Forward vs Reverse mode: From a math point of view, this describes the order in which we expand the chain rule, or rather whether we want to calculate the vector-Jacobian product or the Jacobian-vector product. It has a very large impact on performance, and the following config parameters differ slightly based on the mode, but it has no effect on the results beyond performance. Technically something in between Forward and Reverse would be possible, but this has not been explored much by JAX, Enzyme, or others.
  3. The activity of the arguments, in the order in which they appear in f. If we use reverse mode and return a non-() value, we specify an activity for the return value too. We are open to discussing the actual names here: Const, Active, Duplicated, DuplicatedNoNeed.

rustc Design:

This chapter is not relevant for autodiff users, but might still be interesting for those curious to integrate Enzyme into a new language. It is mostly tailored towards people already working on rustc. I would like to claim it is also there to help reviewers, but realistically I just write things down here because I can't remember what I coded yesterday. This is likely incomplete, so please create PRs to extend it!

The first step was to integrate Enzyme (core) itself. We have added it as a submodule in src/tools/enzyme and updated the bootstrapping accordingly. We had added our autodiff macro in /library, but this approach had to be abandoned for cross-compilation. The current macro is implemented as a rustc-internal macro.

We had to alter the compilation pipeline in a few places when autodiff is used. The obvious one is that we have to prevent our source function from getting completely inlined.

compiler/rustc_ast/src/expand/autodiff_attrs.rs: This is a new file containing the logic to parse our autodiff macro into the rustc-builtin rustc_autodiff attributes. Having to go through the macro blocks users from tinkering with our internals. It has the nice side effect that we can implement an erroring dummy fallback if we see that the llvm_enzyme config hasn't been set when building rustc.

compiler/rustc_codegen_llvm/src/attributes.rs: In from_fn_attrs we query cx.tcx.autodiff_attrs(instance.def_id()); and check whether autodiff is active for the function (that is, whether it is a source or a placeholder). This is to make sure that neither of the two gets inlined; for that we mark them as InlineAttr::Never and pray for the best.

compiler/rustc_codegen_ssa/src/codegen_attrs.rs: We added autodiff_attrs, in which we parse rustc_autodiff attributes applied to functions and create AutoDiffAttrs out of them, which we return.

compiler/rustc_monomorphize/src/partitioning.rs: In place_mono_items we check whether the characteristic_def_id of a function exists, and if so (and if it has autodiff attrs) we block it from being inserted into internalization_candidates to prevent inlining.

compiler/rustc_monomorphize/src/partitioning.rs In collect_and_partition_mono_items we update things.

compiler/rustc_codegen_ssa/src/back/write.rs: In generate_lto_work we pass the autodiff worklist to our backend autodiff function, which does the actual work. There we also check that the autodiff worklist is empty if we don't use fat-lto.

compiler/rustc_middle/src/ty/mod.rs How to create a typetree or fnctree from a rustc_middle::ty

Unsafe Interface

Reverse mode AD in particular involves code transformations which can easily cause UB when not used correctly. We work on catching all such cases through our design and additional safety checks, but are aware that for some cases, like hot loops or GPU kernels, such checks are not desired.

A motivational (C++) example is given here. This example would be instant UB in Rust, because types must never be in an invalid state, which is not guaranteed for loss here. In our safe interface we solve this by accepting the type by value, making the invalid state inaccessible from Rust code. This however requires allocating a new variable for each call, which could become a performance limitation.

While not implemented yet, we propose a second interface, with the bikeshedding name unsafe_ad instead of autodiff. It will not generate the safety checks which we discussed in the previous sections. The generated interface also differs for DualOnly and DuplicatedOnly. See these examples:

fn f(x: &[f32], y: &mut f32) {
    *y = x[0] * x[0] + x[1] * x[0];
}

#[autodiff(df, Forward, Dual, Dual)]
fn f(x: &[f32], y: &mut f32) { ... }
// fn df(x: &[f32], dx: &mut [f32], y: &mut f32, dy: &mut f32);

fn main() {
    let x      = [2.0, 2.0];
    let mut dx = [1.0, 0.0];
    let mut y  = 0.0;
    let mut dy = 0.0;
    df(&x, &mut dx, &mut y, &mut dy);
}

The first optimization here is using DualOnly. Now

#![allow(unused)]
fn main() {
#[autodiff(df, Forward, DualOnly, DualOnly)]
fn f(x: &[f32], y: &mut f32) { ... }
// fn df(x: Vec<f32>, dx: &mut[f32], y: f32, dy: &mut f32);
}

Both x and y become inaccessible, so no harm can happen. But let us assume that we have to reuse these expensive memory allocations.

#![allow(unused)]
fn main() {
#[unsafe_ad(df, Forward, DualOnly, DualOnly)]
fn f(x: &[f32], y: &mut f32) { ... }
// unsafe fn df(x: MaybeUninit<&[f32]>, dx: &mut[f32], y: MaybeUninit<&mut f32>, dy: &mut f32);
}

We expect both x and y to be in a valid state when calling df; however, they will be in an invalid state afterwards. This allows us to omit computing their original output values, and it also allows us to reuse their memory locations as buffer memory.

Acknowledgments