Implementing a Simple Neural Network Framework from Scratch
Despite my experience in the AI ecosystem, I recently realized that I didn’t fully understand backpropagation and gradient updates within neural networks. In this article, I aim to rectify this by implementing a simple neural network framework from scratch. This framework will provide a thorough yet easy-to-follow dive into the topic.
Fundamentally, a neural network is a mathematical function that maps inputs to desired outputs. Even a simple network with two layers and one input can be viewed this way. Consider the following example:
*Figure: A simple neural net with two layers and a ReLU activation. The linear layers have weights wₙ and biases bₙ.*
We can represent this neural network as a function by going layer by layer, starting from the input. For this example, the function would look like this:
```
At the input, we start with the identity function pred(x) = x
At the first linear layer, we get pred(x) = w₁x + b₁
The ReLU nets us pred(x) = max(0, w₁x + b₁)
At the final layer, we get pred(x) = w₂(max(0, w₁x + b₁)) + b₂
```
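To make this concrete, here is the final expression written as an ordinary Python function, a sketch with the weights and biases passed in as plain floats:

```python
def pred(x: float, w1: float, b1: float, w2: float, b2: float) -> float:
    """The whole two-layer network, composed into a single expression."""
    hidden = max(0.0, w1 * x + b1)  # first linear layer + ReLU
    return w2 * hidden + b2         # second linear layer
```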
Although these functions grow unwieldy for more complicated networks, the point stands: any neural network can be represented as a single composed mathematical function.
We can make these functions more useful for computation by parsing them into a syntax tree. Each leaf node in the tree represents a parameter, constant, or input, and every other node represents an elementary operation that takes its children as arguments. Thinking of a neural network as a tree of elementary operations lets us implement forward propagation and backpropagation as simple recursive algorithms.
Here is the recursive tree-node class at the core of our framework, implemented in Python:
```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class NeuralNetNode:
    """A node in our neural network tree"""
    children: List['NeuralNetNode'] = field(default_factory=list)

    def op(self, x: List[float]) -> float:
        """The operation that this node performs"""
        raise NotImplementedError

    def forward(self) -> float:
        """Evaluate this node by recursively evaluating its children"""
        return self.op([child.forward() for child in self.children])

    # This is just for convenience
    def __call__(self) -> float:
        return self.forward()

    def __repr__(self):
        return f'{self.__class__.__name__}({self.children})'
```
Suppose we have a differentiable loss function for our neural network, such as mean squared error (MSE). We can then update the parameters (the leaf nodes of the tree) based on the loss value. To do this, we need the derivative of the loss function with respect to each parameter, and the chain rule lets us compute these derivatives by breaking them down into simpler ones.
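As a quick worked example (using a bare linear model pred(x) = wx + b to keep the algebra short), the chain rule unfolds the MSE loss L = (pred(x) − y)² like so:

```
dL/dpred = 2(pred(x) − y)
dpred/dw = x          dpred/db = 1
dL/dw = dL/dpred · dpred/dw = 2(pred(x) − y) · x
dL/db = dL/dpred · dpred/db = 2(pred(x) − y)
```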
The recursive tree structure of the neural network works well with the chain rule. Each elementary operation knows its derivative with respect to all of its arguments. By propagating the derivative from the parent operation to the child operations through simple multiplication, we can compute the derivative of each node with respect to the loss function. Here is an example of the backward propagation algorithm implemented in Python:
```python
@dataclass
class NeuralNetNode:
    ...

    def grad(self) -> List[float]:
        """The gradient of this node with respect to its inputs"""
        raise NotImplementedError

    def backward(self, derivative_from_parent: float):
        """Propagate the derivative from the parent to the children"""
        self.on_backward(derivative_from_parent)
        deriv_wrt_children = self.grad()
        for child, derivative_wrt_child in zip(self.children, deriv_wrt_children):
            child.backward(derivative_from_parent * derivative_wrt_child)

    def on_backward(self, derivative_from_parent: float):
        """Hook for subclasses to override. Things like updating parameters"""
        pass
```
We can now define input nodes, parameter nodes, and operations in our framework. Here are examples of how to implement them in Python:
```python
from dataclasses import dataclass, field
import random

@dataclass
class Input(NeuralNetNode):
    """A leaf node that represents an input to the network"""
    value: float = 0.0

    def op(self, x):
        return self.value

    def grad(self) -> List[float]:
        return [1.0]

    def __repr__(self):
        return f'{self.__class__.__name__}({self.value})'

@dataclass
class Parameter(NeuralNetNode):
    """A leaf node that represents a parameter of the network"""
    value: float = field(default_factory=lambda: random.uniform(-1, 1))
    learning_rate: float = 0.01

    def op(self, x):
        return self.value

    def grad(self) -> List[float]:
        return [1.0]

    def on_backward(self, derivative_from_parent: float):
        # Gradient descent: step against the incoming derivative
        self.value -= derivative_from_parent * self.learning_rate

    def __repr__(self):
        return f'{self.__class__.__name__}({self.value})'

@dataclass
class Operation(NeuralNetNode):
    """A node that performs an operation on its inputs"""
    pass
```
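As a quick sanity check of the Parameter node (the numbers here are illustrative, not from the framework itself):

```python
p = Parameter(value=0.5, learning_rate=0.1)
print(p.forward())  # 0.5
p.backward(2.0)     # gradient descent step: 0.5 - 2.0 * 0.1
print(p.value)      # ≈ 0.3
```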
We can also implement specific operations such as addition, multiplication, ReLU, and sigmoid:
```python
from typing import List
import math

@dataclass
class Add(Operation):
    """A node that adds its inputs"""
    def op(self, x):
        return sum(x)

    def grad(self) -> List[float]:
        return [1.0] * len(self.children)

@dataclass
class Multiply(Operation):
    """A node that multiplies its inputs"""
    def op(self, x):
        return math.prod(x)

    def grad(self) -> List[float]:
        # The derivative w.r.t. each child is the product of all the other children
        grads = []
        for i in range(len(self.children)):
            cur_grad = 1
            for j in range(len(self.children)):
                if i == j:
                    continue
                cur_grad *= self.children[j].forward()
            grads.append(cur_grad)
        return grads

@dataclass
class ReLU(Operation):
    """A node that applies the ReLU function to its input."""
    def op(self, x):
        return max(0, x[0])

    def grad(self) -> List[float]:
        return [1.0 if self.children[0].forward() > 0 else 0.0]

@dataclass
class Sigmoid(Operation):
    """A node that applies the sigmoid function to its input."""
    def op(self, x):
        return 1 / (1 + math.exp(-x[0]))

    def grad(self) -> List[float]:
        # sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))
        return [self.forward() * (1 - self.forward())]
```
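We can also check that gradients flow through an operation as expected. This is a small sketch; the expected numbers follow directly from the derivative rules above:

```python
a = Parameter(value=2.0, learning_rate=0.1)
b = Parameter(value=3.0, learning_rate=0.1)
product = Multiply([a, b])

print(product.forward())  # 6.0
product.backward(1.0)     # d(product)/da = b = 3.0, d(product)/db = a = 2.0
print(a.value, b.value)   # ≈ 1.7 and ≈ 2.8 after the gradient descent steps
```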
To define a neural network in our framework, we can construct a tree-like structure. Here is an example of a simple linear classifier implemented using our framework:
```python
linear_classifier = Add([Multiply([Parameter(), Input()]), Parameter()])
```
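Larger networks compose the same way. As a sketch, the two-layer ReLU network from the start of the article, pred(x) = w₂·max(0, w₁x + b₁) + b₂, could be built like this (the variable names are my own):

```python
x = Input()
hidden = ReLU([Add([Multiply([Parameter(), x]), Parameter()])])       # max(0, w1*x + b1)
two_layer_net = Add([Multiply([Parameter(), hidden]), Parameter()])   # w2*hidden + b2
```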
To use our models for prediction, we populate the inputs in the tree and call the forward() method on the root node. Here is how to do this in Python:
```python
class Operation(NeuralNetNode):
    ...

    def find_input_nodes(self) -> List[Input]:
        """Find all of the input nodes in the subtree rooted at this node"""
        input_nodes = []
        for child in self.children:
            if isinstance(child, Input):
                input_nodes.append(child)
            elif isinstance(child, Operation):
                input_nodes.extend(child.find_input_nodes())
        return input_nodes

    def predict(self, inputs: List[float]) -> float:
        """Evaluate the network on the given inputs"""
        input_nodes = self.find_input_nodes()
        assert len(input_nodes) == len(inputs)
        for input_node, value in zip(input_nodes, inputs):
            input_node.value = value
        return self.forward()
```
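For example, evaluating the linear classifier from earlier on a single input (the output depends on the randomly initialized parameters):

```python
output = linear_classifier.predict([2.0])
print(output)  # w * 2.0 + b for the randomly initialized w and b
```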
Training our models is now straightforward. Here is an example of a training function implemented in Python:
```python
from typing import Callable, Tuple

def train_model(model: Operation,
                loss_fn: Callable[[float, float], float],
                optimizer: Callable[[float, float, float], float],
                data: List[Tuple[List[float], float]],
                num_epochs: int):
    for epoch in range(num_epochs):
        for inputs, target in data:
            prediction = model.predict(inputs)
            loss = loss_fn(prediction, target)
            # Assumption: the optimizer callable maps (prediction, target, loss)
            # to the derivative of the loss w.r.t. the prediction, which seeds
            # backpropagation at the root of the tree.
            model.backward(optimizer(prediction, target, loss))
```
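As a minimal end-to-end sketch (the loss function and its derivative below are my own choices, not part of the framework), we can fit the linear classifier to y = 2x + 1 with mean squared error:

```python
def mse(prediction: float, target: float) -> float:
    return (prediction - target) ** 2

def mse_derivative(prediction: float, target: float, loss: float) -> float:
    # Derivative of MSE with respect to the prediction
    return 2 * (prediction - target)

# Learn y = 2x + 1 from a handful of points
data = [([x], 2 * x + 1) for x in [-2.0, -1.0, 0.0, 1.0, 2.0]]
train_model(linear_classifier, mse, mse_derivative, data, num_epochs=200)
print(linear_classifier.predict([3.0]))  # should be close to 7.0
```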