Differentiable Programming (1)


Today, we’ll work out the process of differentiating a simple function.

This is the base case for a computation graph. Once we have the base case, we can look forward to the inductive case, and then generalize for more complex functions.

Every function we look at here can be decomposed into a graph of elementary operators and operands. Operators and operands alike are the nodes of that graph.
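
To make the decomposition concrete, here is a minimal sketch that writes one small function as a graph of nodes. The function `g(x, y) = x * y + x`, the classes `Operand` and `Operator`, and the helper `count_nodes` are all assumptions made for illustration; none of them is part of the examples that follow.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Operand:
    name: str


@dataclass(frozen=True)
class Operator:
    name: str
    inputs: Tuple["GraphNode", ...]


GraphNode = Union[Operand, Operator]

x, y = Operand("x"), Operand("y")
# g(x, y) = x * y + x, written as a graph of elementary nodes
g = Operator("add", (Operator("mul", (x, y)), x))


def count_nodes(node: GraphNode) -> int:
    """Count operator and operand occurrences reachable from `node`."""
    if isinstance(node, Operand):
        return 1
    return 1 + sum(count_nodes(child) for child in node.inputs)


print(count_nodes(g))   # 5: add, mul, x, y, x
```

The operand `x` is counted twice because it feeds two operators. Each example below is the smallest possible graph of this kind: a single operator with its operands.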

Examples

We start with the simplest function: $\mathbf{id}(x)$.

This function takes an input value and returns the same value, hence the name $\mathbf{id}$, short for the identity function.

It doesn’t do anything interesting, but it is an excellent entry point into understanding how to implement differentiation programmatically. Later, we will add complexity by moving to two functions that take two inputs: $\mathbf{add}(x, y)$ and $\mathbf{mul}(x, y)$.

We will also see that there is a pattern to them. Let’s get started.

1. Id

First, we define our function $f$ and its derivative $f'$:

$$
\newcommand{\id}{\mathbf{id}}
\begin{align*}
f(x) &= \id(x) &= x\\
f'(x) &= dx &= 1
\end{align*}
$$

Let’s take a moment to get familiar with the components of the diagram below.

Figure: computation node for the function Id.

from dataclasses import dataclass

@dataclass
class Node:
    data: float

    def f(self):    # the forward function: id(x) = x
        return self

    def f_(self):   # the reverse function: the derivative of id is 1
        return Node(1.0)
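
As a quick check of the node above, the sketch below repeats the class so that it runs on its own, then calls both functions once. The variable names `x`, `y`, and `dy` and the sample value `3.0` are illustrative assumptions.

```python
from dataclasses import dataclass


@dataclass
class Node:
    data: float

    def f(self):    # forward: id(x) = x
        return self

    def f_(self):   # reverse: the derivative of id is the constant 1
        return Node(1.0)


x = Node(3.0)
y = x.f()     # the forward pass hands back the very same node
dy = x.f_()   # the reverse pass ignores the value of x

print(y is x, y.data, dy.data)   # True 3.0 1.0
```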

2. Add

The binary operator $(+)$ requires two parameters, so we add one more input, $y$, in this example. We define our function $f$ and its derivative $f'$:

$$
\newcommand{\add}{\mathbf{add}}
\begin{align*}
f(x,y) &= \add(x,y) &= x + y\\
f'(x,y) &= D_xf + D_yf &= 1 + 1
\end{align*}
$$

Let’s take a moment to get familiar with the components of the diagram below.

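Figure: computation node for the function Add.

from dataclasses import dataclass

@dataclass
class Node:
    data: float

    def f(self, other: "Node"):   # the forward function: x + y
        return Node(self.data + other.data)

    def f_(self):                 # the reverse function: D_x f + D_y f = 1 + 1
        return Node(2.0)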
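
The following sketch exercises the Add node with sample inputs. It repeats the class so that it is self-contained, and it quotes the `"Node"` annotation so the class can refer to itself while it is still being defined. The sample values and the step `h` are assumptions chosen so that the arithmetic is exact.

```python
from dataclasses import dataclass


@dataclass
class Node:
    data: float

    def f(self, other: "Node") -> "Node":   # forward: add(x, y) = x + y
        return Node(self.data + other.data)

    def f_(self) -> "Node":                 # reverse: D_x f + D_y f = 1 + 1
        return Node(2.0)


x, y = Node(3.0), Node(4.0)
total = x.f(y)    # forward result, 3.0 + 4.0
slope = x.f_()    # reverse result, independent of x and y

# Nudging both inputs by h moves the sum by 2 * h, matching the slope.
h = 0.5
nudged = Node(x.data + h).f(Node(y.data + h))
print(total.data, slope.data, (nudged.data - total.data) / h)   # 7.0 2.0 2.0
```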

3. Mul

Finally, let’s look at the example for $\mathbf{mul}(x,y)$. This is also a binary operator. We define our function $f$ and its derivative $f'$:

$$
\newcommand{\mul}{\mathbf{mul}}
\begin{align*}
f(x,y) &= \mul(x,y) &= x \times y\\
f'(x,y) &= D_xf + D_yf &= y + x
\end{align*}
$$

Let’s take a moment to get familiar with the components of the diagram below. Notice that, apart from the name and the polymorphic $f$ and $f'$, the graph has the same form as the one for $\mathbf{add}(x,y)$.

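Figure: computation node for the function Mul.

from dataclasses import dataclass

@dataclass
class Node:
    data: float

    def f(self, other: "Node"):    # the forward function: x * y
        return Node(self.data * other.data)

    def f_(self, other: "Node"):   # the reverse function: D_x f + D_y f = y + x
        return Node(other.data + self.data)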
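
Unlike Add, the derivative of Mul depends on the inputs, so it is worth checking numerically. The sketch below assumes that the reverse function `f_` receives the second operand as a parameter (it needs `y` to compute `y + x`), and it compares the result with finite-difference estimates of the two partial derivatives. The step size `h` and the tolerance are arbitrary choices for illustration.

```python
from dataclasses import dataclass


@dataclass
class Node:
    data: float

    def f(self, other: "Node") -> "Node":    # forward: mul(x, y) = x * y
        return Node(self.data * other.data)

    def f_(self, other: "Node") -> "Node":   # reverse: D_x f + D_y f = y + x
        return Node(other.data + self.data)


x, y = Node(3.0), Node(4.0)
product = x.f(y)
slope = x.f_(y)

# Finite-difference estimates of the two partial derivatives.
h = 1e-6
d_x = (Node(x.data + h).f(y).data - product.data) / h   # close to y
d_y = (x.f(Node(y.data + h)).data - product.data) / h   # close to x

print(product.data, slope.data)              # 12.0 7.0
print(abs(d_x + d_y - slope.data) < 1e-4)    # True
```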

Finally

To evaluate the computation graph in each case, we invoke the forward function and then the reverse function. This gives us two results: the value of the function and the value of its derivative.

Because there is only one node in each of these graphs, we are done.
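
One way to picture this evaluation is a helper that runs both functions of a single-node graph and returns the pair of results. The sketch below is an assumption about how that could look, not the layout used in the sections above: the `OPS` table, the `evaluate` helper, and the choice to store each operator as a pair of plain functions are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Node:
    data: float


# Each operator is a pair (forward, reverse) of plain functions on Nodes.
OPS = {
    "id":  (lambda x: x,                        lambda x: Node(1.0)),
    "add": (lambda x, y: Node(x.data + y.data), lambda x, y: Node(2.0)),
    "mul": (lambda x, y: Node(x.data * y.data), lambda x, y: Node(y.data + x.data)),
}


def evaluate(op: str, *inputs: Node) -> tuple[float, float]:
    """Run the forward and reverse functions of a single-node graph."""
    forward, reverse = OPS[op]
    return forward(*inputs).data, reverse(*inputs).data


x, y = Node(3.0), Node(4.0)
print(evaluate("id", x))       # (3.0, 1.0)
print(evaluate("add", x, y))   # (7.0, 2.0)
print(evaluate("mul", x, y))   # (12.0, 7.0)
```

Giving the forward and reverse functions the same inputs keeps the call uniform across the three operators, which is the pattern the examples share.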

Notice the blue line in the middle of the last two diagrams. Imagine printing the diagram on a piece of paper and folding it along the blue line. The nodes in the top half land on the nodes in the bottom half: $x$ on $dx$, and $y$ on $dy$.

We use a single graph to evaluate both a function and its derivative.

Because each of these graphs has only one node, the derivative components are outputs that go unused. However, we will see that the member functions `f` (forward) and `f_` (reverse) can be expanded to do more work and make use of the derivative components. They will let us model the Chain Rule for more complex functions.

Until next time.