
Differentiable Programming in Kotlin



Irene Dea discusses the new differentiable programming framework Facebook is developing for Kotlin.


Irene Dea is a Software Engineer at Facebook.

About the conference

QCon Plus is a virtual conference for senior software engineers and architects that covers the trends, best practices, and solutions leveraged by the world's most innovative software organizations.


Dea: My name is Irene Dea. I'll be giving a talk on the new differentiable programming framework we're developing at Facebook. What is differentiable programming? Differentiable programming brings gradient based optimization techniques from machine learning into general programming. With machine learning, many algorithms are now being learned instead of explicitly written by programmers. In the words of Yann LeCun, it's really very much like a regular program, except it's parameterized, automatically differentiated, and trainable or optimizable.


Of course, there are many popular frameworks that do provide differentiation, such as PyTorch, TensorFlow, and JAX. However, these frameworks are heavily geared towards traditional machine learning models. There are many use cases that are not well supported, such as computer graphics, physics simulations, and probabilistic programming.

What We Need

Let's take a look at what we need to cover these other use cases. First, we need a fast language for writing computationally intensive code wherever it's needed. Most popular frameworks today use Python as the surface language, which makes custom logic inefficient. Of course, we need automatic differentiation so users can obtain the gradients of their computations automatically. Next, we need memory safety, as it can be difficult and frustrating to track down memory issues when you have other goals in mind. Type safety is also very helpful, especially when writing large programs. Finally, static compilation enables many useful optimizations before runtime.

In the end, it all comes down to performance, usability, and flexibility. On the performance side, users should be able to write performant code that can be productionized easily and works well in mixed-workload applications. On the usability side, users should be able to develop and iterate on new models effectively. Debuggability, for example, is a huge necessity: bugs should be surfaced quickly and be easy to understand. With regard to flexibility, our framework can be applied to a wide range of use cases. Most frameworks work exceptionally well on certain use cases, like traditional machine learning models, but once you step outside those boundaries and write custom code, you quickly lose out on usability and performance.

Our Approach

To address these three major needs, we have taken a compiler-aware approach. We provide a customizable, extensible API that serves as the base toolkit for differentiability and building ML models. On top of that, we provide compile time optimizations, and compile time shape checking enabled through compiler plugins.

Why Kotlin?

Why Kotlin? We chose Kotlin because it's easy to use, memory safe, and type safe, and it has functional constructs. It's extensible via compiler plugins, and it's performant. Kotlin is a JVM language, which means it can run anywhere from half the speed of C++ to sometimes faster than C++, and, of course, orders of magnitude faster than Python. Kotlin is developed by JetBrains, which also created the popular IDE IntelliJ. As a result, Kotlin has excellent, seamless IDE support, which is crucial for developers. You might be wondering why we're using an Android language. In fact, there's a talk happening called "Kotlin Is Way More Than Just Android." Kotlin is the official Android language, but it can be used for a wide variety of other applications, such as server-side applications and writing DSLs. JetBrains even has a dedicated team of developers working on Kotlin for data science; for example, they're working on notebook support. Furthermore, Kotlin's interoperability with Java gives it access to the entire Java ecosystem, which includes numerous popular libraries. In addition to the JVM, Kotlin can also target native code via LLVM, or JavaScript. With these two other backends, Kotlin can be used not only for Android and the other applications we discussed, but also for iOS and web programming. Kotlin is a well-loved language with an active developer community and millions of users. We firmly believe that Kotlin is a perfect language for compiler-aware differentiability.

API: Derivatives and Extensibility

Let's take a look at our API. Our API is pure and functional, which means that the values we work with are immutable. This gives us both performance and usability advantages. Our API supports both scalar and tensor math. For example, on this slide, we have a function f that's equivalent to sine. To compute the first derivative, you pass the value x at which you're evaluating the derivative, and a reference to the function you're differentiating. In this example, the function fp, which is the derivative of f, is approximately equal to cosine. We support both forward-mode and reverse-mode differentiation. By nesting derivative calls, we can also compute higher-order derivatives. We also support derivatives of multivariate functions, of functions taking or returning user-defined types, and of tensor functions, such as computing the Jacobian or the Hessian.
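The call pattern above can be sketched with forward-mode dual numbers, where each value carries its derivative alongside it. This is a minimal sketch, not the framework's API. `Dual`, `derivative`, and `f` are hypothetical names, and only `plus`, `times`, and `sin` are implemented. Reverse mode and nested (higher-order) derivatives are omitted.

```kotlin
// Forward-mode dual number: a value together with its derivative (tangent).
data class Dual(val value: Double, val tangent: Double) {
    // Sum rule: (u + v)' = u' + v'
    operator fun plus(other: Dual) = Dual(value + other.value, tangent + other.tangent)

    // Product rule: (u * v)' = u' * v + u * v'
    operator fun times(other: Dual) =
        Dual(value * other.value, tangent * other.value + value * other.tangent)
}

// Chain rule for sine: sin(u)' = cos(u) * u'
fun sin(x: Dual) = Dual(kotlin.math.sin(x.value), kotlin.math.cos(x.value) * x.tangent)

// Hypothetical helper: evaluates f'(x) by seeding the input's tangent with 1.0.
fun derivative(x: Double, f: (Dual) -> Dual): Double = f(Dual(x, 1.0)).tangent

// The slide's example: f is equivalent to sine.
fun f(x: Dual): Dual = sin(x)

fun main() {
    val x = 0.5
    val fp = derivative(x, ::f) // approximately cos(x)
    println("f'($x) = $fp, cos($x) = ${kotlin.math.cos(x)}")
}
```

Passing the function as a reference (`::f`) or as a lambda (`derivative(3.0) { it * it }`) mirrors the "value plus function reference" shape described in the talk.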

We also support many other components needed in practice for AI applications. Notably, our API provides the tools to tackle traditional ML problems, such as layers and optimizers. In addition, the API is designed to be customizable and extensible. For example, you can differentiate functions whose inputs or outputs are user-defined types, and you can add your own trainable layers and components. Finally, our API is designed to be optimizable by a compiler plugin.
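As an illustration of user-defined trainable components, here is a sketch of a custom layer that plugs into a generic optimizer through an interface. `Trainable`, `Linear`, `gradientDescent`, and `fitLine` are hypothetical names. To keep the example self-contained, the gradient is derived by hand rather than by AD. Values are immutable, in the spirit of the pure, functional API described above.

```kotlin
// Hypothetical interface for parameters that an optimizer can update.
interface Trainable<T : Trainable<T>> {
    // Returns a new value moved against the gradient; the receiver is never mutated.
    fun step(gradient: T, learningRate: Double): T
}

// A user-defined layer y = w * x + b with immutable parameters.
data class Linear(val w: Double, val b: Double) : Trainable<Linear> {
    fun forward(x: Double) = w * x + b

    override fun step(gradient: Linear, learningRate: Double) =
        Linear(w - learningRate * gradient.w, b - learningRate * gradient.b)
}

// Gradient of the loss 0.5 * (forward(x) - y)^2 with respect to (w, b), derived by hand.
fun lossGradient(layer: Linear, x: Double, y: Double): Linear {
    val err = layer.forward(x) - y
    return Linear(err * x, err)
}

// Generic gradient descent that works for any Trainable type.
fun <T : Trainable<T>> gradientDescent(initial: T, steps: Int, lr: Double, grad: (T) -> T): T {
    var params = initial
    repeat(steps) { params = params.step(grad(params), lr) }
    return params
}

// Full-batch training: average the per-example gradients at every step.
fun fitLine(data: List<Pair<Double, Double>>, steps: Int = 2000, lr: Double = 0.05): Linear =
    gradientDescent(Linear(0.0, 0.0), steps, lr) { layer ->
        val grads = data.map { (x, y) -> lossGradient(layer, x, y) }
        Linear(grads.sumOf { it.w } / data.size, grads.sumOf { it.b } / data.size)
    }

fun main() {
    // Points on the line y = 2x + 1
    println(fitLine(listOf(0.0 to 1.0, 1.0 to 3.0, 2.0 to 5.0)))
}
```

Using the layer type itself as its gradient type keeps the sketch small. A real framework may distinguish a type from its tangent (gradient) type.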

API: Performance

Now that we've covered the different pieces of the API, let's talk about performance. As mentioned earlier, Kotlin is performant. This is exciting because it allows differentiability to be used in mixed-workload applications. If you're writing a performance-critical application, you can add pieces of differentiation without having to write logic in Python, as you would in other frameworks. For more traditional ML use cases, we've hooked into MKL-DNN, the go-to library for machine learning ops on CPU, and we continue to work on C++ speedups. We have demonstrated performance on par with other popular frameworks. We also support sparse tensors. Sparse data is everywhere. For example, social network companies like Facebook have a lot of graph-type problems, and therefore a lot of sparse data. Many frameworks support some sparse tensor operations but have little to no support for their gradients, which prevents learning on sparse weights. So far, we have seen an order-of-magnitude performance increase with our sparse tensors. We're also developing a couple of compile-time optimizations that work with our library, and our API is specifically designed to support them.
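Here is a concrete version of the sparse-gradient point. Take a linear model with sparse input x. The gradient of the squared error with respect to the weights is (w.x - y) * x, which is nonzero only where x is. The sketch below uses a map-based `SparseVector`, a hypothetical stand-in and not the framework's sparse tensor type, and updates only those coordinates.

```kotlin
// Sparse vector: only nonzero entries are stored, as index -> value.
typealias SparseVector = Map<Int, Double>

// Dot product that visits only the nonzero entries of the sparse input.
fun dot(weights: DoubleArray, x: SparseVector): Double {
    var acc = 0.0
    for ((i, v) in x) acc += weights[i] * v
    return acc
}

// Gradient of 0.5 * (w.x - y)^2 with respect to w is (w.x - y) * x,
// so it is nonzero only at the indices where x is nonzero.
fun sparseGradient(weights: DoubleArray, x: SparseVector, y: Double): SparseVector {
    val err = dot(weights, x) - y
    return x.mapValues { (_, v) -> err * v }
}

// Gradient step that touches only the coordinates present in the sparse gradient.
fun applyUpdate(weights: DoubleArray, grad: SparseVector, lr: Double) {
    for ((i, g) in grad) weights[i] -= lr * g
}

fun main() {
    val weights = DoubleArray(1_000_000)                  // one million dense weights
    val x: SparseVector = mapOf(3 to 1.0, 999_999 to 2.0) // only two nonzero features
    val grad = sparseGradient(weights, x, y = 1.0)
    applyUpdate(weights, grad, lr = 0.1)
    println("gradient entries: ${grad.size}, w[3] = ${weights[3]}, w[999999] = ${weights[999_999]}")
}
```

The update costs time proportional to the number of nonzero features, not the number of weights. That asymmetry is what gradient support for sparse tensors can exploit.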

AD Optimize Plugin

Now that we've covered our API, let's look at some of the compiler optimizations we're developing, starting with AD Optimize. Our AD implementation produces a compute tree for evaluating the derivative; the tree is built at runtime, with a node created for each operation. This approach comes at the cost of extra allocations and function calls, a common problem in automatic differentiation frameworks. Our AD Optimize plugin addresses this cost by inlining differentiable computations and unboxing scalars.
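A runtime compute tree of this kind can be sketched as a small reverse-mode AD in which every operation allocates a `Node`. `Node`, `backward`, and `nodesCreated` are hypothetical names, and `foo` follows the talk's example. The code shows the general technique, not the framework's internals. The loop assumes the geometric series runs for i = 0 until 1000; the talk does not spell out the bounds.

```kotlin
// Counts how many compute-tree nodes have been allocated.
var nodesCreated = 0

// One node per operation: its value, its parents paired with local derivatives, and an accumulated gradient.
class Node(val value: Double, val parents: List<Pair<Node, Double>> = emptyList()) {
    var grad = 0.0

    init { nodesCreated++ }

    // d(u + v)/du = 1 and d(u + v)/dv = 1
    operator fun plus(other: Node) = Node(value + other.value, listOf(this to 1.0, other to 1.0))

    // d(u / c)/du = 1 / c for a constant c
    operator fun div(c: Double) = Node(value / c, listOf(this to 1.0 / c))
}

// Reverse pass: order nodes so each is processed before its parents, then apply the chain rule.
fun backward(output: Node) {
    val order = mutableListOf<Node>()
    val visited = HashSet<Node>()
    fun visit(n: Node) {
        if (!visited.add(n)) return
        for ((parent, _) in n.parents) visit(parent)
        order.add(n) // post-order: parents are added before the node itself
    }
    visit(output)
    output.grad = 1.0
    for (n in order.asReversed()) {
        for ((parent, local) in n.parents) parent.grad += local * n.grad
    }
}

// The talk's foo: a geometric series with r = 1/2, assuming i runs from 0 until 1000.
fun foo(a: Node): Node {
    var y = Node(0.0)
    var p = 1.0 // 2^i
    repeat(1000) {
        y = y + a / p // two new nodes per iteration: a divide and an add
        p *= 2.0
    }
    return y
}

fun main() {
    nodesCreated = 0
    val a = Node(3.0)
    val y = foo(a)
    backward(y)
    println("nodes = $nodesCreated, foo(3) = ${y.value}, dfoo/da = ${a.grad}")
}
```

Running it allocates 2,002 nodes. That is one for the input a, one for the initial sum, and two per iteration.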

Example - Geometric Series (r=1/2)

Let's take a look at an example. Here we have a function foo that computes the geometric series with r equal to one-half. The variable y holds the sum, and in the loop we accumulate a divided by 2 to the power of i. To compute the derivative, we create a compute tree. Here is what the compute tree looks like for three iterations: each operation produces a node in the tree. The loop in foo actually runs 1,000 times, and each iteration performs a couple of operations, an add and a divide, so we're creating roughly 2,000 node objects at runtime just to compute the derivative. How can we do better? With the AD Optimize plugin, we can unbox scalars and inline derivative values to drastically reduce the number of objects we create. For this example, this results in the creation of just a single object for the derivative computation. We can do even more.
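One way to see what unboxing and inlining buy is to write the same derivative by hand in forward mode. The value and its derivative are carried as two primitive `Double`s through the loop. This is a hand-written approximation of the effect, not the code the AD Optimize plugin actually emits. `fooWithDerivative` is a hypothetical name, and the loop assumes i runs from 0 until 1000.

```kotlin
// Value and derivative of foo(a) = sum of a / 2^i for i in 0 until 1000 (assumed bounds),
// carried as primitive doubles: no node object is allocated per operation.
fun fooWithDerivative(a: Double): DoubleArray {
    var y = 0.0  // primal value
    var dy = 0.0 // derivative dy/da, propagated alongside the value
    var p = 1.0  // 2^i
    repeat(1000) {
        y += a / p
        dy += 1.0 / p // d(a / p)/da = 1 / p
        p *= 2.0
    }
    return doubleArrayOf(y, dy) // the only allocation: the returned pair of results
}

fun main() {
    val (value, derivative) = fooWithDerivative(3.0)
    println("foo(3) = $value, dfoo/da = $derivative")
}
```

The only allocation is the returned `DoubleArray`, compared with one heap object per operation in a node-based compute tree.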

Coarsening Optimization: Concept

Let's take a look at the next optimization, coarsening, a novel optimization technique that we have developed. There are two main ways to compute derivatives: algorithmic differentiation and symbolic differentiation. Algorithmic differentiation differentiates every operation at runtime; it has the finest granularity of operation. Symbolic differentiation applies calculus to the entire computation; it has the largest granularity of operation. Coarsening introduces a new way to do AD that strikes a balance between these two methods, getting the best of both worlds.
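The symbolic end of this spectrum can be sketched with a tiny expression tree. The tree is differentiated once, before any evaluation, by applying calculus rules to the whole expression. `Expr`, `d`, and `eval` are hypothetical names for illustration. Coarsening itself mixes this with algorithmic differentiation, which the sketch does not attempt.

```kotlin
// A tiny expression language over a single variable x.
sealed class Expr {
    data class Const(val c: Double) : Expr()
    object X : Expr()
    data class Add(val l: Expr, val r: Expr) : Expr()
    data class Mul(val l: Expr, val r: Expr) : Expr()
    data class Sin(val e: Expr) : Expr()
    data class Cos(val e: Expr) : Expr()
}

// Symbolic derivative with respect to x: rewrites the whole expression using calculus rules.
fun d(e: Expr): Expr = when (e) {
    is Expr.Const -> Expr.Const(0.0)
    is Expr.X -> Expr.Const(1.0)
    is Expr.Add -> Expr.Add(d(e.l), d(e.r))
    is Expr.Mul -> Expr.Add(Expr.Mul(d(e.l), e.r), Expr.Mul(e.l, d(e.r))) // product rule
    is Expr.Sin -> Expr.Mul(Expr.Cos(e.e), d(e.e))                        // chain rule
    is Expr.Cos -> Expr.Mul(Expr.Mul(Expr.Const(-1.0), Expr.Sin(e.e)), d(e.e))
}

// Evaluates an expression at a given x.
fun eval(e: Expr, x: Double): Double = when (e) {
    is Expr.Const -> e.c
    is Expr.X -> x
    is Expr.Add -> eval(e.l, x) + eval(e.r, x)
    is Expr.Mul -> eval(e.l, x) * eval(e.r, x)
    is Expr.Sin -> kotlin.math.sin(eval(e.e, x))
    is Expr.Cos -> kotlin.math.cos(eval(e.e, x))
}

fun main() {
    val f = Expr.Mul(Expr.X, Expr.Sin(Expr.X)) // f(x) = x * sin(x)
    val df = d(f)                              // built once, symbolically
    println(eval(df, 1.0))                     // sin(1) + cos(1)
}
```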

Coarsening Optimization: Workflow

With coarsening, the AD optimizer takes the primal code as input, identifies segments of interest using a reuse-aware algorithm, raises those segments to the symbolic level, performs symbolic differentiation on them, and generates optimized code that mixes algorithmic and symbolic differentiation. The key point is that coarsening has a wider view of the computation than standard AD, which enables more optimizations.

Coarsening: Geometric Series (r=1/2)

Let's revisit the geometric series function. With coarsening, we're able to consider the entire function's computation and recognize certain patterns. Coarsening can transform loops into summations, and even simplify summations further. Here, we can use the closed-form property of the geometric series to simplify the foo function, and then take the derivative of this simplified function. On the slide are the functions generated by our coarsening optimization: the primal computation and the gradient computation. We arrived at the primal computation through the geometric series identity shown on the slide, the one that looks like a screenshot from a textbook, and the gradient was computed symbolically from it. We've effectively reduced foo to a short expression with no loops, and the same goes for the gradient. Also, notice that to obtain the derivative we now only call fooGrad. Previously, the AD system would call the primal function foo behind the scenes. Now there are no calls to foo at all; we can just call fooGrad directly.
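The talk does not spell out the loop bounds. Assuming the index runs from i = 0 to n - 1 with n = 1000, the identity behind the simplified primal and its symbolic gradient is:

```latex
\mathrm{foo}(a) = \sum_{i=0}^{n-1} \frac{a}{2^{i}}
  = a \cdot \frac{1 - (1/2)^{n}}{1 - 1/2}
  = 2a\left(1 - 2^{-n}\right),
\qquad
\mathrm{fooGrad}(a) = \frac{d\,\mathrm{foo}}{da} = 2\left(1 - 2^{-n}\right).
```

For n = 1000, the term 2^{-1000} is far below double precision relative to 1. fooGrad(a) therefore evaluates to exactly 2 in floating point, with no loop and no call to foo.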

Time Reduction on Hookean Spring

Let's take a look at a more complex example. This slide shows the performance of Hookean Spring, a physics simulation program that simulates mass-spring systems. The three configurations correspond to three sizes of the spring system, in terms of the number of spring vertices: 10, 20, and 40 vertices. Coarsening is able to do symbolic differentiation on the entire gradient computation. As a result, the primal computation, which computes the system energy, can be removed completely; you can see that the primal time is reduced to zero. This is similar to the previous example, where we only had to call fooGrad to get the derivative. The observed speedups here are 4x to 11x, and the program even runs faster than the original primal computation alone. We've also evaluated coarsening on other examples and observed speedups of one to two orders of magnitude.
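The talk does not give the exact energy used in the benchmark. The sketch below assumes the standard Hookean energy for a 2D system: the sum over springs of (k/2)(|p_i - p_j| - rest)^2. It illustrates how a symbolically derived gradient can be evaluated without ever calling the primal `energy` function. `Spring`, `energy`, and `energyGrad` are hypothetical names.

```kotlin
import kotlin.math.sqrt

// A spring between vertices i and j with stiffness k and rest length rest.
data class Spring(val i: Int, val j: Int, val k: Double, val rest: Double)

// Primal: total elastic energy. pos holds 2D coordinates as [x0, y0, x1, y1, ...].
fun energy(pos: DoubleArray, springs: List<Spring>): Double {
    var total = 0.0
    for (s in springs) {
        val dx = pos[2 * s.i] - pos[2 * s.j]
        val dy = pos[2 * s.i + 1] - pos[2 * s.j + 1]
        val stretch = sqrt(dx * dx + dy * dy) - s.rest
        total += 0.5 * s.k * stretch * stretch
    }
    return total
}

// Gradient derived symbolically from the energy; it never calls energy():
// dE/dp_i = k * (|p_i - p_j| - rest) * (p_i - p_j) / |p_i - p_j|, and dE/dp_j is its negation.
fun energyGrad(pos: DoubleArray, springs: List<Spring>): DoubleArray {
    val g = DoubleArray(pos.size)
    for (s in springs) {
        val dx = pos[2 * s.i] - pos[2 * s.j]
        val dy = pos[2 * s.i + 1] - pos[2 * s.j + 1]
        val len = sqrt(dx * dx + dy * dy)
        val coef = s.k * (len - s.rest) / len
        g[2 * s.i] += coef * dx
        g[2 * s.i + 1] += coef * dy
        g[2 * s.j] -= coef * dx
        g[2 * s.j + 1] -= coef * dy
    }
    return g
}

fun main() {
    // Three vertices connected in a triangle.
    val pos = doubleArrayOf(0.0, 0.0, 1.5, 0.0, 0.0, 0.8)
    val springs = listOf(Spring(0, 1, 2.0, 1.0), Spring(1, 2, 1.0, 1.5), Spring(0, 2, 3.0, 1.0))
    println(energyGrad(pos, springs).toList())
}
```

A central finite difference on `energy` is a cheap way to check a hand-derived gradient like this one.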

Static Shape Checking

Now that we've gone over compiler optimizations for performance, let's see how we're using compiler plugins to enhance usability. Tensors are often fed through many different operations. Each operation often has its own shape requirements and produces a new tensor with a possibly different shape. The combination of shape requirements and new output shapes makes it incredibly easy to hit runtime shape errors with popular frameworks, which offer no static shape checking or shape information. Debugging runtime shape errors is hard, and many users rely heavily on printing shapes at runtime to debug errors, or even just to understand what their code is doing. To address this issue, we're developing a compiler plugin for static shape checking. With this plugin, users will get not only compile-time shape inference and shape checking, but also real-time feedback in IntelliJ, such as error messages and redlining. In IntelliJ, users can also inspect the shapes of their tensors as they develop, before they even build or run their code. The plugin is integrated with our API, so you get static shape checking out of the box for numerous tensor operations. Lastly, the plugin's functionality is extensible. Many tensor operations have complex shape transformations, so it's important that users can define their own shape functions.

Example: Static Shape Checking

Here's an example of how you can provide static shape checking for the matmul function, which implements matrix multiplication. Matmul takes two two-dimensional tensors and requires that their inner dimensions match. Here we can see that the second dimension of x and the first dimension of y are both the parameter B. Now let's look at matmul in action. The shape of a is inferred to be (1, 2), and the shape of b is inferred to be (2, 3). The value res comes from a correct usage of matmul that produces a tensor of shape (1, 3). badRes, on the other hand, shows an incorrect usage of matmul, which results in a compile-time error because the inner dimensions do not match.
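Plain Kotlin has no type-level integers, but phantom type parameters can approximate the idea of a shared B forcing the inner dimensions to agree. The marker objects `D1`, `D2`, `D3` and the `Matrix` class are hypothetical. They track only dimension labels, not actual sizes, whereas the framework's plugin infers concrete shapes such as (1, 2). Because `Matrix` is invariant in its type parameters, a mismatched call fails to compile.

```kotlin
// Marker types standing in for dimension sizes (plain Kotlin has no type-level integers).
object D1
object D2
object D3

// A row-major matrix tagged with phantom row (R) and column (C) dimension types.
// The type parameters are invariant, so Matrix<D2, D3> and Matrix<D3, D3> never unify.
class Matrix<R, C>(val rows: Int, val cols: Int, val data: DoubleArray) {
    init { require(data.size == rows * cols) { "data size must be rows * cols" } }
    operator fun get(i: Int, j: Int): Double = data[i * cols + j]
}

// The shared type parameter B plays the role of the slide's B: x's second dimension
// must be y's first dimension, or the call does not type-check.
fun <A, B, C> matmul(x: Matrix<A, B>, y: Matrix<B, C>): Matrix<A, C> {
    val out = DoubleArray(x.rows * y.cols)
    for (i in 0 until x.rows) {
        for (j in 0 until y.cols) {
            var acc = 0.0
            for (k in 0 until x.cols) acc += x[i, k] * y[k, j]
            out[i * y.cols + j] = acc
        }
    }
    return Matrix(x.rows, y.cols, out)
}

fun main() {
    val a = Matrix<D1, D2>(1, 2, doubleArrayOf(1.0, 2.0))
    val b = Matrix<D2, D3>(2, 3, doubleArrayOf(1.0, 0.0, 1.0, 0.0, 1.0, 1.0))
    val res = matmul(a, b) // Matrix<D1, D3>
    // val badRes = matmul(b, b) // compile error: expects Matrix<D3, C>, got Matrix<D2, D3>
    println(res.data.toList())
}
```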

Static Shape Checking: IntelliJ

Here are a couple of examples of what the plugin looks like in IntelliJ. At the top, we have shape inspection; at the bottom, we have an error message. This is all happening during development, which means that users can inspect their code and see errors immediately as they type.

More Complex Shape Checking

Let's take a look at an example of more complex shape checking. The shape of the matmul function was simple and could be expressed using positional matching. However, broadcast is an example of a shape transformation that's not so easily expressed. Broadcast is a very dynamic shape transformation that lets users add tensors of different shapes. At the top, we have the broadcast shape function; users can write this code imperatively and even call other functions. At the bottom, we have the definition of add. Notice that the return shape is a call to the shape function broadcast on A and B. This kind of extensibility allows users to define custom shape checking logic for new tensor operations.
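A shape function like broadcast can be written as ordinary imperative Kotlin. The sketch below assumes NumPy-style broadcasting rules, which the talk does not spell out: shapes are aligned from the right, and each dimension pair must be equal or contain a 1. In the framework, the plugin would evaluate such a function at compile time. Here it runs at runtime and throws on a mismatch. The name `broadcast` follows the talk, but the signature is hypothetical.

```kotlin
// Computes the broadcast result shape of two tensor shapes (NumPy-style rules assumed).
fun broadcast(a: List<Int>, b: List<Int>): List<Int> {
    val rank = maxOf(a.size, b.size)
    val result = MutableList(rank) { 0 }
    // Walk both shapes from the last dimension backwards; missing dimensions count as 1.
    for (k in 1..rank) {
        val da = if (k <= a.size) a[a.size - k] else 1
        val db = if (k <= b.size) b[b.size - k] else 1
        result[rank - k] = when {
            da == db -> da
            da == 1 -> db
            db == 1 -> da
            else -> throw IllegalArgumentException("Cannot broadcast $a with $b")
        }
    }
    return result
}

fun main() {
    println(broadcast(listOf(3, 1), listOf(4)))       // [3, 4]
    println(broadcast(listOf(2, 3, 4), listOf(3, 1))) // [2, 3, 4]
}
```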

Use Case: Probabilistic Programming

Now that we've gone over all the major components of our framework, let's look at an interesting use case. Bean Machine is a probabilistic programming system for Bayesian models that's being developed at Facebook. Traditional differentiation and machine learning frameworks were not a good fit for it. In particular, traditional frameworks generally lack higher-order differentiation, performant scalar support, sparse tensor support, and fast execution of native-language code. For example, Newtonian Monte Carlo is a probabilistic inference algorithm that uses second-order differentiation and is well suited to scalars. Our framework can support these use cases, and we're collaborating with Bean Machine to provide the differentiability infrastructure they need.
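Newtonian Monte Carlo relies on second derivatives of the log density. The sketch below shows one way to obtain them with a second-order forward-mode "jet" and use them in a Newton step, x - f'(x)/f''(x). It is not Bean Machine's API and not the full NMC algorithm. `Jet`, `normalLogDensity`, and `newtonStep` are hypothetical names. A Gaussian log density is quadratic, so a single Newton step from any starting point lands on the mean, up to rounding.

```kotlin
// Second-order forward-mode "jet": value, first derivative, and second derivative
// with respect to a single scalar input.
data class Jet(val v: Double, val d1: Double, val d2: Double) {
    operator fun plus(o: Jet) = Jet(v + o.v, d1 + o.d1, d2 + o.d2)
    operator fun minus(o: Jet) = Jet(v - o.v, d1 - o.d1, d2 - o.d2)

    // Product rule up to second order: (fg)'' = f''g + 2f'g' + fg''
    operator fun times(o: Jet) =
        Jet(v * o.v, d1 * o.v + v * o.d1, d2 * o.v + 2.0 * d1 * o.d1 + v * o.d2)

    companion object {
        fun variable(x: Double) = Jet(x, 1.0, 0.0) // the input: dx/dx = 1
        fun constant(c: Double) = Jet(c, 0.0, 0.0)
    }
}

// Log density of Normal(mu, sigma), dropping the additive constant.
fun normalLogDensity(x: Jet, mu: Double, sigma: Double): Jet {
    val z = x - Jet.constant(mu)
    return Jet.constant(-0.5 / (sigma * sigma)) * z * z
}

// Newton step on a log density: x - f'(x) / f''(x), which needs second-order derivatives.
fun newtonStep(x: Double, logDensity: (Jet) -> Jet): Double {
    val j = logDensity(Jet.variable(x))
    return x - j.d1 / j.d2
}

fun main() {
    val mu = 1.5
    val sigma = 2.0
    val lp = normalLogDensity(Jet.variable(0.3), mu, sigma)
    println("d/dx = ${lp.d1}, d2/dx2 = ${lp.d2}") // d2 equals -1 / sigma^2
    println(newtonStep(3.7) { normalLogDensity(it, mu, sigma) }) // close to mu
}
```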


I started by saying that what we're doing really comes down to performance, usability, and flexibility. Here's how all the pieces we've discussed fit into that idea. For performance, we saw the benefits of sparse tensors, MKL-DNN, being in Kotlin, and our optimization plugins. With regard to usability, we talked about our functional API and our static shape checking plugin. For flexibility, we saw that our API is extensible and customizable, and that the shape checking plugin is extensible as well. We also touched on our collaboration with Bean Machine, a probabilistic programming system.

Future Work

What do we have planned for the future? On the optimization side, we want to develop new optimizations enabled by compiler plugins. For example, there may be room for more domain-specific optimizations, and because these are all plugins, users can pick and choose which ones to apply depending on their goals. On the usability front, there is a lot of other metadata besides tensor shapes that could help users detect bugs early and ergonomically. Lastly, we're excited to see users apply our framework to their own innovative purposes, and we're looking forward to getting their feedback and working with them.

Questions and Answers

Schuster: It's always good to learn about all the new research going on in this area. It definitely gave me some flashbacks to calculus class, which is my downside, but it's useful. This is a library in Kotlin. How do you achieve the analysis of the code? You're talking about compiler plugins. Is it like analyzing and interfering, sort of interspersing itself into the compiler build chain? Is that how it works?

Dea: Yes, exactly. We have compiler plugins, and they run at different stages of the compiler pipeline. For example, the static shape checking plugin runs in the frontend, and one of the things we intercept is call resolution. That's a big part of it, because we have to check each call and make sure that each call produces the right shape. Plugins like AD Optimize, which generate code, run further down the compiler pipeline, on the Kotlin IR, right before it gets translated to JVM bytecode, native code, or JavaScript.

Schuster: So that's basically a feature of Kotlin. You're not doing anything special here; you just plug in, and the compiler generates new builds.

Dea: Yes, we're just using the general Kotlin plugin framework.

Schuster: How many language features do you support? Can I use everything in Kotlin? Can everything be differentiated?

Dea: It depends. Not everything can be differentiated, but our API lets users add their own custom differentiable objects. If you want your custom object to be differentiated, you would have to extend an interface so that it fits into our API. As for things like control flow, we're planning to handle that with compiler plugins as well.

Schuster: That's easy for users to do, they don't have to write their own plugins.

Dea: Yes. We're planning to support the general features of Kotlin, so users can differentiate through most features, including control flow.

Schuster: What does it mean to differentiate over control flow?

Dea: To differentiate over control flow? To back up: if you're in a framework like PyTorch and you have an if statement or a for loop, you're actually just tracing through the for loop. We talked about building a compute tree; every operation you do in the for loop gets unrolled into this compute tree. That's what happens in a framework like PyTorch. For us, that is also happening, but what people actually prefer is for the control flow to be preserved, so that the dynamism is kept later on in their models. In other frameworks, the loop gets unrolled completely, so you lose the sense that you had control flow, and because you lose it, you also can't do optimizations on it.

Schuster: If you don't optimize the control flow away, you can debug into it, I suppose. Is that right?

Dea: Yes, exactly.

Schuster: What do you use this for? Is it for writing general ML models just in a more language based way? Is that it?

Dea: Yes. The other popular frameworks are better suited to that use case. The use cases we're looking at are more outside of that box. For example, we talked about probabilistic programming. People tend to use probabilistic programming for things like climate models. There's also a probabilistic programming model that's used to match gamers of the same skill level against each other. Other places you might see differentiable programming used include ray tracers and physics simulations.

Schuster: For the ray tracing example, it's used to optimize the ray tracer or to learn anything about it?

Dea: Yes. For the ray tracing example, I think it's generally used for optimization. You're learning with respect to the properties of the scene: your objects might have certain properties, such as a position, and you're learning with respect to those things.

Schuster: You brought up probabilistic programming. Is it too much to ask for a quick overview of what that is?

Dea: Probabilistic programming is basically a way to encode uncertainty into your models. What happens in probabilistic programming is you provide some assumption about your world, and that might be in the form of distributions. You provide some statistical assumptions. Then you provide some observations on top of that. You're basically saying, here's how the world is acting. Then you say, this is what I'm seeing. Then what comes out is a distribution that says, this is what might actually be happening, given those two things, the observations and your assumptions. It allows people to add a level of uncertainty, but also some domain knowledge into their models.
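The workflow described here combines assumptions expressed as distributions with observations to yield a posterior distribution. The simplest conjugate model illustrates it: a Beta prior on a coin's bias, updated with Bernoulli observations. This analytic update is a textbook example, not Bean Machine or the Kotlin framework. Probabilistic programming systems use general inference algorithms when no closed form exists. `Beta` and `posterior` are hypothetical names.

```kotlin
// Belief about a coin's bias p, as a Beta(alpha, beta) distribution.
data class Beta(val alpha: Double, val beta: Double) {
    val mean: Double get() = alpha / (alpha + beta)
}

// Conjugate update: Bernoulli observations turn Beta(a, b) into Beta(a + successes, b + failures).
fun posterior(prior: Beta, observations: List<Boolean>): Beta {
    val successes = observations.count { it }
    val failures = observations.size - successes
    return Beta(prior.alpha + successes, prior.beta + failures)
}

fun main() {
    val prior = Beta(2.0, 2.0)     // assumption: the bias is probably near 0.5
    val data = List(10) { it < 7 } // observations: 7 successes, 3 failures
    val post = posterior(prior, data)
    println("$post, posterior mean = ${post.mean}") // posterior is Beta(9.0, 5.0) with mean 9/14
}
```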

Schuster: I think Facebook does a lot of research into that. I think we had a QCon talk two years ago from your colleague. If you would like to look at that more, just look for InfoQ, Probabilistic Programming. There's a talk online.

Dea: I'm not sure if I quite understand the question here, because we would need to have continuous functions in order to take the derivative of them. If you are writing something that you're taking the derivative of, it would need to be a continuous function.

Schuster: You gave the example of differentiating the sine function. How fancy can I get with functions in there? How much do you support, if I have like Bernoulli, something like that?

Dea: You can get pretty fancy. We do have a suite of examples. Also, generally, we do support anything traditional machine learning supports as well. Those models do get pretty fancy, too.

Schuster: Is this open source? Is this online? Are you going to open source it? Can people play around with this somewhere?

Dea: We're working on open sourcing our framework; that is the plan. Right now, we're getting feedback internally, and we have some internal users. We don't have any concrete plans to open source it soon.

Schuster: Do you work with Erik Meijer directly?

Dea: Erik Meijer actually started this team. We worked together in the beginning.

Schuster: I saw some tweets over the last year or two with him doing stuff with math in Kotlin. Everyone was like, what's he up to? How many monads does he force you to use?

Dea: Working with Erik is quite interesting. It's also really interesting to work at the intersection of machine learning and PL, because there is so much overlap there, and much of it is unexplored. For example, I think static shape checking has a lot of interesting overlap with the PL world. Originally, we intended for that to be in the type checker, but we can't extend Kotlin in that direction. We actually started by hacking the compiler; our first prototype was just in the type checker. It's interesting to see how PL and ML fit together here, especially with that specific feature.

Schuster: It's a tradeoff of having the neat, clean compiler plugins versus having your own branch of Kotlin, which makes it harder to use.

Dea: We went in the compiler plugin direction because the hacked compiler was hard to maintain and keep up to date with upstream. The compiler plugins are really nice, and we're working with JetBrains on their usability.

Schuster: Can you maybe put your work in context of the wider research ecosystem. I think there was, unfortunately, a project called Swift for TensorFlow, which I think is related to what you're doing in a way, but for the Swift language. Are there others that we might want to watch while we wait for you to open source this?

Dea: There are others. Of course, there's PyTorch, TensorFlow, and JAX, which are very popular. Other than those, there's also Julia, which has Zygote, so they have differentiation in Julia as well.

Schuster: Is Julia as extensible as Kotlin? Do they use macros or something?

Dea: I know that it's not statically compiled. In terms of types, it's between static and dynamic typing. It's not statically compiled, so there wouldn't be like a compiler plugin infrastructure.

Schuster: If you're looking at something like PyTorch, or Keras, are those the same way where you can write models in a language and then figure out the graphs. Is that right, or is it a different thing?

Dea: Keras is a layer or a machine learning library on top of TensorFlow. It's a high level library. That is a little bit different from PyTorch.

Schuster: I was just wondering if there's other tools that are similar to the work that you're doing with allowing people to write models in a language rather than plugging graphs together.

Dea: Julia is probably the most similar one. PyTorch is really easy to use, because you're not metaprogramming, whereas in TensorFlow you are metaprogramming and you have to think about the actual graphs that you are building. For example, instead of just having like a variable, you would have a variable node that you're updating. There's this level of separation with TensorFlow. Keras is a library on top of that, and so, there's also a degree of separation too because it's a high level library on top of TensorFlow.

Schuster: It used to be in the olden days, you had to literally plug together a graph. If you read the old TensorFlow examples, it was like, I have to do 20 pages of code just to have one network and something. I think these language based approaches are just neater, and cleaner.




Recorded at: Jan 07, 2022