Differentiable Programming (2)
In this post, we will work out the differentiation process of a very simple function.
We will use the simplest possible function to drill down into the basics.
As seen previously, the identity function doesn't do anything other than return the exact input it is given. However, precisely because of this trivial behavior, we can focus on the key implementation details rather than getting bogged down by the rules of math.
Setup
We define our function and its derivative:

f(x) = x
f'(x) = 1
Next, we model the computation graph. Because the function takes a single input, the evaluation process maps nicely onto a linked list.
Modeling
Now, we map and connect the internals of the nodes together. Remember from the last post that a node structure hosts two chain links:
- forward: nothing to compute, so we simply pass on the symbol x
- reverse: computing the derivative is trivial (dx/dx resolves to 1), so we just pass on 1
This time, we implement some working code: a node now also takes the function to be used in the forward pass.
from dataclasses import dataclass
from typing import Callable

@dataclass
class Node:
    data: float
    func: Callable[[float], float]
    name: str  # label only, for readability

    def forward(self):
        return self.func(self.data)

    def reverse(self, d):
        return 1  # for this example, the input `d` doesn't matter

# the identity function
def idFunc(x: float):
    return x
1. Evaluate Path
2. Differentiate Path
Implementation
We assume that the expression is already parsed and ordered into the following list of operators:
id1 = Node(1, idFunc, "id")
id2 = Node(id1.data, idFunc, "id")
id3 = Node(id2.data, idFunc, "id")
expression = [id1, id2, id3]
Next, we define the accumulator functions that traverse the list and evaluate y as well as its derivative dy/dx:
def evaluate(graph: list[Node]):
    y = None
    for each in graph:
        y = each.forward()
    return y

def differentiate(graph: list[Node]):
    _, *rest = reversed(graph)
    d = 1
    for each in rest:
        d = each.reverse(d)
    return d
Finally, a simple sanity test:
assert 1 == evaluate(expression)
assert 1 == differentiate(expression)
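As a quick variation (self-contained here so it runs on its own, mirroring the snippets above), seeding the same chain with a different value still returns that value and a derivative of 1, which is the hallmark of the identity function:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Node:
    data: float
    func: Callable[[float], float]
    name: str

    def forward(self):
        return self.func(self.data)

    def reverse(self, d):
        return 1  # identity: the local derivative is always 1

def idFunc(x: float):
    return x

# the same three-node identity chain, seeded with 5 instead of 1
nodes = [Node(5, idFunc, "id")]
for _ in range(2):
    nodes.append(Node(nodes[-1].data, idFunc, "id"))

y = None
for node in nodes:                 # forward pass
    y = node.forward()

d = 1
for node in reversed(nodes[:-1]):  # reverse pass, seed node skipped
    d = node.reverse(d)

assert y == 5
assert d == 1
```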
Mathing the Math
Let’s take a moment to think about where d comes from and why it is needed. Otherwise, we wouldn’t be able to fully audit more complex graphs.
What actually happens is that every node is its own symbolic expression: it can be used as is or unfolded into its constituents. In other words, a node can resolve either to its value or to its local derivative. I rewrote the definition, changing only the symbols, to make clear that each term is a scalar quantity:

dy/dx = (d id3 / d id2) · (d id2 / d id1) · (d id1 / dx)
And because they are scalar quantities, we can fold them together into a single quantity d. The symbol d consumes the local derivative at each node, and then has its new state propagated to the next node.
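Written out for the three identity nodes in our chain, this fold is just the chain rule; each factor is a node's local derivative, and their product collapses to 1:

```latex
\frac{dy}{dx}
  = \frac{d\,\mathrm{id}_3}{d\,\mathrm{id}_2}
    \cdot \frac{d\,\mathrm{id}_2}{d\,\mathrm{id}_1}
    \cdot \frac{d\,\mathrm{id}_1}{dx}
  = 1 \cdot 1 \cdot 1
  = 1
```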
At the final node, we arrive at the basis of the input(s): x itself, whose derivative dx/dx is 1.
Simply put, at the very first node x, the derivative resolves to dx/dx, or 1; we treat this subgraph as a unit of computation. From there, we simply yield 1 for every operator, which means we could skip the list and just use the last node. It is possible to hand-wave this away for this example; however, in more complex graphs, operator nodes might no longer yield just 1. A multiplication node, for example, has a local derivative that depends on its inputs.
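To make this concrete, here is a hypothetical `Scale` node (my illustration, not part of the post's example) whose local derivative is its constant factor rather than 1, so its `reverse` must actually use and update `d`:

```python
from dataclasses import dataclass

@dataclass
class Scale:
    data: float    # incoming value
    factor: float  # forward computes factor * data

    def forward(self):
        return self.factor * self.data

    def reverse(self, d):
        # the local derivative is `factor`, so the accumulated
        # derivative `d` is multiplied, not merely passed along
        return d * self.factor

# chain y = 3 * (2 * x) evaluated at x = 1, so dy/dx = 6
s1 = Scale(1.0, 2.0)
s2 = Scale(s1.forward(), 3.0)

d = 1
for node in (s2, s1):  # walk the chain in reverse
    d = node.reverse(d)

assert d == 6.0
```

Note that, unlike the identity chain, every node's `reverse` has to run here, because each one contributes a non-trivial factor to `d`.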
Now we understand where d comes from, and why we must collect, compute, and pass on its new state. On the other hand, a node’s total derivative is just the sum of its partial derivatives, so we can be sure that d is always a scalar quantity.
Summary
In this post, we have worked out the process of differentiating over a simple function. This process operates on top of a computation graph (in this example, a simple linked list).
We will see that, as a function grows in complexity, we naturally need to track the input flowing into each node. This is required in order to compute the local derivatives and ultimately propagate them through the graph to get the total derivative.
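As a small preview of that bookkeeping, a node can cache the input it sees during `forward` so that `reverse` can use it later. This is sketched with a hypothetical `Square` node (my illustration, not from the post), whose local derivative 2x depends on the input seen during the forward pass:

```python
from dataclasses import dataclass

@dataclass
class Square:
    x: float = 0.0  # cached input, filled in during forward

    def forward(self, x: float):
        self.x = x  # remember the input for the reverse pass
        return x * x

    def reverse(self, d):
        # the local derivative of x^2 is 2x, read from the cache
        return d * 2 * self.x

sq = Square()
y = sq.forward(3.0)
assert y == 9.0
assert sq.reverse(1) == 6.0  # dy/dx at x = 3
```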
In the next post, we will apply this process to functions of multiple inputs.
Until next time.