Differentiable Programming (1)
Today, we’ll work out the process of differentiating a simple function.
This is the base case for a computation graph. Once we have the base case, we can move on to the inductive case and then generalize to more complex functions.
Any function built from elementary operations can be decomposed into a graph of elementary operators and operands. These are the nodes of the graph.
Examples
We start with the simplest function: $f(x) = x$.
This function takes an input value and returns the same value; hence the name id, short for the identity function.
It doesn’t do anything interesting, but it is an excellent entry point into understanding how to implement differentiation programmatically. Later, we will add complexity by increasing the number of inputs in two other functions:
- add: $f(x, y) = x + y$
- mul: $f(x, y) = x \cdot y$
We will also see that there is a pattern to them. Let’s get started.
1. Id
First, we define our function $f(x) = x$ and its derivative $f'(x) = 1$:
Let’s take a moment to get familiar with the components of the diagram below.
- The components:
  - the Node, the Forward chain, the Reverse chain
- The game rules for constructing the graph:
  - at each node, we define
    - the forward function to compute $f(x)$
    - the reverse function to compute its derivative $f'(x)$
```python
from dataclasses import dataclass

@dataclass
class Node:
    data: float

    def f(self):  # the forward function: id(x) = x
        return self

    def f_(self):  # the reverse function: id'(x) = 1
        return Node(1)
```
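One way to sanity-check a reverse function is to compare it against a numerical derivative. The minimal sketch below repeats the Id node so that it runs on its own and adds a hypothetical helper, central_difference, which approximates the derivative as (f(x + h) - f(x - h)) / 2h; the step h = 1e-6 is an arbitrary choice, not something from the original.

```python
from dataclasses import dataclass


@dataclass
class Node:
    data: float

    def f(self) -> "Node":
        # Forward function of id: return the input unchanged.
        return self

    def f_(self) -> "Node":
        # Reverse function of id: d/dx x = 1 for every x.
        return Node(1.0)


def central_difference(fn, x: float, h: float = 1e-6) -> float:
    # Hypothetical helper: approximates fn'(x) by (fn(x + h) - fn(x - h)) / 2h.
    return (fn(x + h) - fn(x - h)) / (2 * h)


x = Node(3.0)
print(x.f().data)   # 3.0
print(x.f_().data)  # 1.0
print(central_difference(lambda v: Node(v).f().data, x.data))  # close to 1.0
```

The numerical estimate agrees with the exact reverse value up to floating-point rounding, which is a cheap way to catch a wrong f_ before graphs get larger.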
2. Add
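The binary operator $+$ requires two parameters, so this example adds one more input. We define our function $f(x, y) = x + y$ and its derivative, taken here as the sum of the partial derivatives with respect to each input: $f'(x, y) = \frac{\partial f}{\partial x} + \frac{\partial f}{\partial y} = 1 + 1 = 2$.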
Let’s take a moment to get familiar with the components of the diagram below.
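```python
from __future__ import annotations  # lets the annotation refer to Node inside its own class
from dataclasses import dataclass

@dataclass
class Node:
    data: float

    def f(self, other: Node):  # the forward function: x + y
        return Node(self.data + other.data)

    def f_(self):  # the reverse function: df/dx + df/dy = 1 + 1
        return Node(2)
```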
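The constant 2 returned by f_ can be checked numerically, assuming it stands for $\partial f/\partial x + \partial f/\partial y$. In this minimal sketch, partials and add are hypothetical helpers: partials takes a central difference in each input while holding the other fixed, and add wraps the forward function so it works on plain floats.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass
class Node:
    data: float

    def f(self, other: Node) -> Node:
        # Forward function of add: x + y.
        return Node(self.data + other.data)

    def f_(self) -> Node:
        # Reverse function, read as df/dx + df/dy = 1 + 1.
        return Node(2.0)


def partials(fn: Callable[[float, float], float], x: float, y: float,
             h: float = 1e-6) -> Tuple[float, float]:
    # Hypothetical helper: central difference in each input, the other held fixed.
    dx = (fn(x + h, y) - fn(x - h, y)) / (2 * h)
    dy = (fn(x, y + h) - fn(x, y - h)) / (2 * h)
    return dx, dy


def add(a: float, b: float) -> float:
    # Hypothetical wrapper: run the forward function on plain floats.
    return Node(a).f(Node(b)).data


x, y = Node(3.0), Node(4.0)
dx, dy = partials(add, x.data, y.data)
print(x.f(y).data)        # 7.0
print(x.f_().data)        # 2.0
print(round(dx + dy, 6))  # 2.0
```

Each partial derivative of $x + y$ is 1, so their sum matches the 2 that the reverse function returns regardless of the inputs.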
3. Mul
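Finally, let’s look at the example for $\times$ (mul). This is also a binary operator. We define our function $f(x, y) = x \cdot y$ and its derivative, again the sum of the partial derivatives: $f'(x, y) = \frac{\partial f}{\partial x} + \frac{\partial f}{\partial y} = y + x$.
Let’s take a moment to get familiar with the components of the diagram below. Notice that, other than the name and the polymorphic $f$ and $f'$, the form of the graph is identical to add.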
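```python
from __future__ import annotations  # lets the annotation refer to Node inside its own class
from dataclasses import dataclass

@dataclass
class Node:
    data: float

    def f(self, other: Node):  # the forward function: x * y
        return Node(self.data * other.data)

    def f_(self, other: Node):  # the reverse function: df/dx + df/dy = y + x
        return Node(self.data + other.data)
```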
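Unlike add, the reverse function of mul depends on the inputs, so it needs other as a parameter. The minimal sketch below checks it numerically under the same assumption (the derivative is $\partial f/\partial x + \partial f/\partial y$); partials and mul are hypothetical helpers introduced only for this check.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass
class Node:
    data: float

    def f(self, other: Node) -> Node:
        # Forward function of mul: x * y.
        return Node(self.data * other.data)

    def f_(self, other: Node) -> Node:
        # Reverse function, read as df/dx + df/dy = y + x.
        return Node(self.data + other.data)


def partials(fn: Callable[[float, float], float], x: float, y: float,
             h: float = 1e-6) -> Tuple[float, float]:
    # Hypothetical helper: central difference in each input, the other held fixed.
    dx = (fn(x + h, y) - fn(x - h, y)) / (2 * h)
    dy = (fn(x, y + h) - fn(x, y - h)) / (2 * h)
    return dx, dy


def mul(a: float, b: float) -> float:
    # Hypothetical wrapper: run the forward function on plain floats.
    return Node(a).f(Node(b)).data


x, y = Node(3.0), Node(4.0)
dx, dy = partials(mul, x.data, y.data)
print(x.f(y).data)                 # 12.0
print(x.f_(y).data)                # 7.0
print(round(dx, 6), round(dy, 6))  # 4.0 3.0
```

The numerical partials come out as $y = 4$ and $x = 3$, and their sum is the 7 returned by f_.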
Finally
To evaluate the computation graph in each case, we invoke the forward and reverse functions, respectively. This gives two results:
- the result of $f$
- its derivative $f'$
Because there is only one node in each of these graphs, we are done.
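The evaluation step for all three single-node graphs can be sketched without the Node class. This is an illustration under assumptions: GRAPHS and evaluate are hypothetical names, and each graph is modeled as a pair of plain functions over floats that mirror f and f_ from the examples above.

```python
from typing import Callable, Dict, Tuple

# Hypothetical table: each single-node graph is a (forward, reverse) pair,
# mirroring f and f_ from the id, add and mul examples.
GRAPHS: Dict[str, Tuple[Callable[..., float], Callable[..., float]]] = {
    "id": (lambda x: x, lambda x: 1.0),
    "add": (lambda x, y: x + y, lambda x, y: 2.0),
    "mul": (lambda x, y: x * y, lambda x, y: y + x),
}


def evaluate(name: str, *inputs: float) -> Tuple[float, float]:
    # Run the forward function, then the reverse function, on the same inputs.
    forward, reverse = GRAPHS[name]
    return forward(*inputs), reverse(*inputs)


for name, inputs in [("id", (3.0,)), ("add", (3.0, 4.0)), ("mul", (3.0, 4.0))]:
    value, derivative = evaluate(name, *inputs)
    print(name, value, derivative)
# id 3.0 1.0
# add 7.0 2.0
# mul 12.0 7.0
```

Keeping the forward and reverse functions side by side makes it explicit that one graph yields both a value and its derivative.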
Notice the blue line in the middle of the last two diagrams. Imagine you print the diagram on a piece of paper and fold it along the blue line. The nodes in the top half (the forward chain) land on top of the nodes in the bottom half (the reverse chain), each forward node paired with its reverse counterpart.
We use a single graph to evaluate both a function and its derivative.
Because each of these graphs has only one node, the derivative components are simply returned and not used any further. However, we will see that the forward and reverse member functions (f and f_ in the code) can be expanded to do more work and make use of the derivative components. They will let us model the Chain Rule for more complex functions.
Until next time.