Differentiable Programming (2)


In this post, we will work out the differentiation process of a very simple function.

We will use the simplest function, $\mathbf{id}$, to drill down into the basics.

As seen previously, this function does nothing other than return the exact input it is given. However, precisely because of this trivial behavior, we can focus on the key parts of the implementation rather than getting bogged down in the rules of math.

Set Up

We define our functional $\mathbf{F}$ and its derivative:

$$
\begin{align*}
\text{given}&:& f(x) &= x = \mathbf{id}(x) \\
&& f'(x) &= 1
\end{align*}
$$

$$
\begin{align*}
\mathbf{F}(x) &= \mathbf{id}(\mathbf{id}(\mathbf{id}(x))) = \mathbf{id} \circ \mathbf{id} \circ \mathbf{id}\,(x) \\
\mathbf{D}(x) &= 1
\end{align*}
$$

Next, we model the computation graph. Because the functional $\mathbf{F}$ takes a single-variable input, the evaluation process maps nicely to a linked list.

Modeling

Now, we map the internals of the nodes and connect them together. Remember from the last post that a node structure hosts two chain links:

  1. forward: nothing to compute, so we simply pass on the symbol $x$
  2. reverse: the computation is trivial ($\mathbf{unit_x}$ resolves to 1), so we just pass on 1

This time, we implement some working code. Now, a node also takes a function to be used in the forward pass.

from dataclasses import dataclass
from typing import Callable


@dataclass
class Node:
    data: float
    func: Callable[[float], float]
    name: str  # a label for readability, e.g. "id"

    def forward(self) -> float:
        # apply this node's function to its own data
        return self.func(self.data)

    def reverse(self, d: float) -> float:
        return 1  # for this example, the input `d` doesn't matter

# the identity function
def idFunc(x: float) -> float:
    return x
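
As a quick sanity check of a single node in isolation (a minimal sketch reusing the Node and idFunc defined above):

# a lone identity node: forward returns its own data, reverse yields 1
n = Node(3.0, idFunc, "id")
assert n.forward() == 3.0
assert n.reverse(1) == 1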

1. Evaluate Path

2. Differentiate Path

Implementation

We assume that the expression is already parsed and ordered into the following list of operators:

id1 = Node(1, idFunc, "id")
id2 = Node(id1.data, idFunc, "id")
id3 = Node(id2.data, idFunc, "id")

expression = [id1, id2, id3]

Next, we define the accumulator functions to traverse the list and evaluate $\mathbf{F}$ as well as its derivative $\mathbf{D}$:

def evaluate(graph: list[Node]):
    y = None
    for each in graph:
        y = each.forward()
    return y

def differentiate(graph: list[Node]):
    # the output node resolves to its unit (1), so skip it and seed d with 1
    _, *rest = reversed(graph)
    d = 1
    for each in rest:
        d = each.reverse(d)
    return d

Finally, a simple sanity test:

y = evaluate(expression)
dy = differentiate(expression)

assert y == 1   # F(1) == 1
assert dy == 1  # D(1) == 1
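
As a cross-check, composing idFunc by hand should agree with the accumulator (a quick sketch reusing the definitions above):

# composing the identity function three times by hand
composed = idFunc(idFunc(idFunc(1)))
assert composed == evaluate(expression)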

Mathing the Math

Let’s take a moment to think about where $d$ comes from and why it is needed. Without that understanding, we wouldn’t be able to fully audit more complex graphs.

What actually happens is that every node is its own symbolic $\mathbf{unit}$: it can be used as is, or unfolded into its constituents. In other words, a node can either resolve to $\mathbf{unit}$ or to its local derivative. I rewrote the definition, changing only the symbols, to make clear that each term is a scalar quantity:

$$
\begin{align*}
D_z\mathbf{F} &= \mathbf{\color{lightgray}unit_z}\quad,\quad\text{given that }\mathbf{F}=z \\\\
D_y\mathbf{F} &= D_z\mathbf{F} \cdot \mathbf{\color{lightgray}unit_y} \\\\
D_x\mathbf{F} &= D_z\mathbf{F} \cdot D_y\mathbf{F} \cdot \mathbf{\color{lightgray}unit_x} \\\\
&\;\;\vdots \\\\
D_a\mathbf{F} &= D_z\mathbf{F} \cdot D_y\mathbf{F} \cdot \dots \cdot D_a b \cdot \mathbf{\color{lightgray}unit_a}
\end{align*}
$$

And because they are scalar quantities, we can fold them together into a single quantity $d$. The symbol $d$ consumes the local derivative at each node, and its new state is then propagated to the next node:

$$
\begin{align*}
D_z\mathbf{F} &= \mathbf{\color{lightgray}unit_z} \\\\
D_y\mathbf{F} &= \underbrace{D_z\mathbf{F}}_{d} \cdot \mathbf{\color{lightgray}unit_y} \\\\
D_x\mathbf{F} &= \underbrace{d \cdot D_y\mathbf{F}}_{d} \cdot \mathbf{\color{lightgray}unit_x} \\\\
&\;\;\vdots \\\\
D_a\mathbf{F} &= \underbrace{d \cdot D_a b}_{d} \cdot \mathbf{\color{lightgray}unit_a}
\end{align*}
$$

At the final node $b$, whose basis of input(s) is $\mathbf{a}$:

$$
\begin{align*}
D_a\mathbf{F} &= \underbrace{D_z\mathbf{F} \cdot D_y\mathbf{F} \cdot \dots}_{\text{reducer symbol } d} \cdot D_a b \cdot \mathbf{\color{lightgray}unit_a} \\\\
&= d \cdot D_a b \cdot \mathbf{\color{lightgray}unit_a}
\end{align*}
$$

Simply put, at the very first node $z$, the derivative resolves to $\mathbf{unit_z}$, or 1; we treat this subgraph as a unit of computation. From there, we simply yield 1 for every $\mathbf{id}$ operator, which means we could skip the list and just use the last node. It is possible to hand-wave $d$ away in this example; however, for more complex operator nodes, $d$ might no longer yield just 1, as in the example below:

$$
\begin{align*}
\text{given}&:& F(x,y) &= x + y \\
&& D_x(x,y) &= 1 \\
&& D_y(x,y) &= 1 \\\\
\text{then}&:& \underbrace{F'(x,y)}_{d} &= D_x + D_y \\
&& d &= 2
\end{align*}
$$
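
To make this concrete, here is a hedged sketch of what the reverse pass of such an addition node could look like; add_reverse is hypothetical and not part of the implementation above:

# hypothetical reverse pass for a node computing f(x, y) = x + y
def add_reverse(d: float) -> float:
    d_x = 1  # local partial derivative with respect to x
    d_y = 1  # local partial derivative with respect to y
    # fold both local partials into the running scalar d
    return d * (d_x + d_y)

assert add_reverse(1) == 2  # matches d = 2 above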

Now we understand where $d$ comes from, and why we must collect, compute, and pass on its new state. Note also that a node’s total derivative is just the sum of its partial derivatives, so we can be sure that $d$ is always a scalar quantity.
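
In code, this folding is just a running product over the local derivatives, which is what the differentiate accumulator does node by node; here is a minimal sketch of the same idea using functools.reduce over a hypothetical list of already-collected local derivatives:

from functools import reduce

# local derivatives collected along the reverse path (all 1 for id nodes)
local_derivatives = [1, 1, 1]

# d consumes each local derivative in turn and carries its new state forward
d = reduce(lambda acc, local: acc * local, local_derivatives, 1)
assert d == 1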

Summary

In this post, we have worked out the process of differentiating a simple function. This process operates on top of a computation graph (in this example, a simple linked list).

We will see that, as a function grows in complexity, we naturally need to track the incoming input to each node. This is required in order to compute the local derivatives and ultimately propagate them through the graph to get the total derivative.
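
As a hint of what that might look like, here is a hedged sketch in which a node remembers its incoming input during the forward pass; StatefulNode and its cached_input field are hypothetical and only anticipate the next post:

from dataclasses import dataclass
from typing import Callable

@dataclass
class StatefulNode:
    func: Callable[[float], float]
    cached_input: float = 0.0

    def forward(self, x: float) -> float:
        # remember the incoming input so the reverse pass can use it later
        self.cached_input = x
        return self.func(x)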

In the next post, we will apply this process to functions of multiple inputs.

Until next time.