Introduction
Kolmogorov-Arnold Networks, or KANs, represent a significant advance in neural network design. Based on the Kolmogorov-Arnold representation theorem, they offer a potential alternative to Multilayer Perceptrons (MLPs). Whereas MLPs apply fixed activation functions at each node, KANs place learnable activation functions on the edges, replacing linear weights with parametrized splines.
A recent paper titled “KAN: Kolmogorov-Arnold Networks,” by a research team from the Massachusetts Institute of Technology, California Institute of Technology, Northeastern University, and the NSF Institute for Artificial Intelligence and Fundamental Interactions, presents KANs as a promising replacement for MLPs.
Learning Objectives
Understand the Kolmogorov-Arnold Network (KAN), a new type of neural network designed for both accuracy and interpretability.
Implement Kolmogorov-Arnold Networks using Python libraries.
Understand the key differences between Multi-Layer Perceptrons and Kolmogorov-Arnold Networks.
This article is part of the Data Science Blogathon.
Kolmogorov-Arnold representation theorem
The Kolmogorov-Arnold representation theorem states that any continuous function of several variables on a bounded domain can be written as a finite composition of continuous functions of a single variable and the operation of addition.
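Written out in the form used in the KAN paper, for a continuous function f on the unit cube [0,1]^n:

```latex
f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right),
\qquad \phi_{q,p} : [0,1] \to \mathbb{R}, \quad \Phi_q : \mathbb{R} \to \mathbb{R}
```

Each inner function φ_{q,p} and outer function Φ_q takes a single variable. The only operation that combines variables is addition.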
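In its original form, this representation corresponds to a network with only two layers of univariate functions and a hidden layer of width 2n+1, where n is the number of inputs. The univariate functions it requires can also be non-smooth. KANs generalize the representation to arbitrary widths and depths, which the KAN authors argue is better suited to the smooth, compositional functions common in real-world problems.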
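Here is a small instance of the idea. It uses a simple algebraic identity rather than the theorem's general construction. Multiplication is a genuinely two-variable operation, yet it can be written with univariate functions and addition only, because xy = ((x+y)^2 - (x-y)^2)/4. The helper names below are hypothetical and chosen for illustration.

```python
# Multiplication using only univariate functions and addition:
#   x * y = Phi_1(x + y) + Phi_2(x + (-y)),  Phi_1(s) = s**2 / 4,  Phi_2(s) = -s**2 / 4

def phi_outer_plus(s):
    # Outer univariate function applied to the sum x + y
    return s ** 2 / 4

def phi_outer_minus(s):
    # Outer univariate function applied to the sum x + (-y)
    return -(s ** 2) / 4

def multiply_ka(x, y):
    # Inner univariate maps are x -> x, y -> y and y -> -y; they are combined only by addition
    return phi_outer_plus(x + y) + phi_outer_minus(x + (-y))

for x, y in [(0.3, 0.7), (-1.5, 2.0), (4.0, 0.25)]:
    print(x, y, multiply_ka(x, y), x * y)
```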
What are Multi-layer Perceptrons?
Multi-layer Perceptrons (MLPs) are one of the simplest forms of Artificial Neural Networks (ANNs). They are feedforward networks: information flows in one direction, from input to output, with no cycles or loops.
Working of MLPs
Input Layer: Nodes in the input layer represent the features of the input data, with each node corresponding to a specific feature.
Hidden Layers: MLPs include one or more hidden layers between the input and output layers, allowing the network to learn complex patterns and relationships in the data.
Output Layer: The output layer generates the final predictions or classifications.
Connections and Weights: Each connection between neurons in adjacent layers has a weight associated with it, determining its strength. These weights are adjusted during training through backpropagation to minimize the difference between predictions and actual target values.
Activation Functions: Neurons, except those in the input layer, apply an activation function to the weighted sum of their inputs, introducing non-linearity into the network.
Simplified Formula
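In standard notation, an MLP with L layers computes the following. Here W^{(l)} and b^{(l)} are the learnable weights and biases of layer l, and σ is a fixed activation function such as ReLU. The output layer is shown as linear, as is common for regression.

```latex
\mathbf{h}^{(0)} = \mathbf{x}, \qquad
\mathbf{h}^{(l)} = \sigma\!\left( W^{(l)} \mathbf{h}^{(l-1)} + \mathbf{b}^{(l)} \right), \quad l = 1, \dots, L-1, \qquad
\hat{\mathbf{y}} = W^{(L)} \mathbf{h}^{(L-1)} + \mathbf{b}^{(L)}
```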
MLPs are built on the universal approximation theorem, allowing them to represent a wide range of complex functions. However, MLPs have fixed activation functions on each node, which limits their flexibility and interpretability.
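For contrast with KANs, here is a minimal NumPy sketch of an MLP forward pass. The learnable parameters are the linear weights and biases. The ReLU activation is fixed and identical at every hidden node. The helper names (init_mlp, mlp_forward) are illustrative, and ReLU is an assumed choice of activation.

```python
import numpy as np

def relu(z):
    # Fixed activation: the same function at every hidden node, never trained
    return np.maximum(z, 0.0)

def init_mlp(layer_sizes, seed=0):
    """Random weights and zero biases for an MLP with the given layer sizes."""
    rng = np.random.default_rng(seed)
    return [
        (rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_out, n_in)), np.zeros(n_out))
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]

def mlp_forward(params, x):
    """Affine map plus fixed ReLU on hidden layers; linear output layer."""
    h = x
    for W, b in params[:-1]:
        h = relu(W @ h + b)  # only W and b are learnable
    W_out, b_out = params[-1]
    return W_out @ h + b_out

params = init_mlp([2, 5, 1])
y = mlp_forward(params, np.array([0.3, -0.8]))
print(y.shape)  # (1,)
```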
Kolmogorov-Arnold Networks (KANs)
Kolmogorov-Arnold Networks are neural networks with learnable activation functions. Unlike MLPs, where activation functions are fixed at each node, KANs have learnable activation functions on edges, replacing linear weights with parametrized splines.
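To make "learnable activation functions on edges" concrete, here is a minimal NumPy sketch of the forward pass of one KAN-style layer. Every edge from input i to output j carries its own univariate function φ_{j,i}. Each output node simply sums its incoming edge functions, with no separate linear weights. As a simplifying assumption, each φ is a learnable combination of Gaussian bumps on a fixed grid rather than the B-splines KANs actually use. The class name KANLayerSketch is hypothetical, and training (fitting coef by backpropagation) is omitted.

```python
import numpy as np

class KANLayerSketch:
    """Simplified KAN layer: one learnable univariate function per edge.

    phi_{j,i}(x) = sum_k coef[j, i, k] * B_k(x), where B_k are Gaussian bumps on a
    fixed grid (a stand-in for the B-spline basis used by real KANs).
    """

    def __init__(self, in_dim, out_dim, n_basis=8, x_range=(-1.0, 1.0), seed=0):
        rng = np.random.default_rng(seed)
        self.centers = np.linspace(x_range[0], x_range[1], n_basis)
        self.width = self.centers[1] - self.centers[0]
        # Learnable coefficients: n_basis values for each (output j, input i) edge
        self.coef = rng.normal(0.0, 0.1, size=(out_dim, in_dim, n_basis))

    def basis(self, x):
        # x: (batch, in_dim) -> basis values: (batch, in_dim, n_basis)
        return np.exp(-(((x[..., None] - self.centers) / self.width) ** 2))

    def __call__(self, x):
        B = self.basis(x)
        # phi[b, j, i] = sum_k coef[j, i, k] * B[b, i, k]  (one function per edge)
        phi = np.einsum("jik,bik->bji", self.coef, B)
        # Each output node adds up its incoming edge functions
        return phi.sum(axis=2)  # (batch, out_dim)

rng = np.random.default_rng(1)
x = rng.uniform(-1.0, 1.0, size=(4, 2))
hidden = KANLayerSketch(2, 5, seed=0)
output = KANLayerSketch(5, 1, seed=1)
print(output(hidden(x)).shape)  # (4, 1)
```

Stacking two such layers with widths 2 → 5 → 1 mirrors the [2, 5, 1] architecture used in the implementation later in this article.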
Advantages of KANs
KANs offer several advantages:
Greater Flexibility: Learnable activation functions, combined with the edge-based architecture, make KANs highly flexible and better able to represent complex data.
Adaptable Activation Functions: The activation functions in KANs are not fixed like in MLPs. They can adapt and adjust to different data patterns, effectively capturing diverse relationships.
Better Complexity Handling: By replacing linear weights in MLPs with parametrized splines, KANs can handle complex, non-linear data more effectively.
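Superior Accuracy: In the original paper's experiments, much smaller KANs matched or exceeded the accuracy of much larger MLPs on data-fitting and PDE-solving tasks.
Highly Interpretable: KANs can be visualized directly, and their learned activation functions can be inspected to reveal structure and relationships in the data. This makes them easier to interpret.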
Diverse Applications: KANs can perform various tasks such as regression, solving partial differential equations, and continual learning.
Simple Implementation of KANs
Implementing KANs with a simple example involves creating a custom dataset for the function f(x, y) = exp(cos(pi*x) + y^2). This function takes two inputs, calculates the cosine of pi*x, adds the square of y to it, and then calculates the exponential of the result.
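This target function already has the shape of a Kolmogorov-Arnold representation. It applies a univariate function to each input (cos(πx) and y²), adds the results, and applies one outer univariate function (exp). The plain-Python sketch below uses illustrative helper names and checks that the composed form matches the direct formula. In principle, then, a KAN with a single hidden node (width [2, 1, 1]) can represent f exactly, which makes it a convenient test case for pruning.

```python
import math

def f_direct(x, y):
    return math.exp(math.cos(math.pi * x) + y ** 2)

# The same function as a Kolmogorov-Arnold composition
def inner_x(x):
    return math.cos(math.pi * x)  # univariate function of x

def inner_y(y):
    return y ** 2  # univariate function of y

def outer(s):
    return math.exp(s)  # univariate function of the sum

def f_composed(x, y):
    return outer(inner_x(x) + inner_y(y))

print(f_direct(0.0, 0.0))    # exp(1 + 0) = e
print(f_composed(1.0, 1.0))  # cos(pi) + 1 = 0, so exp(0) = 1.0
```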
Library versions used:
Python==3.9.7
matplotlib==3.6.2
numpy==1.24.4
scikit_learn==1.1.3
torch==2.2.2
```python
# Install pykan from GitHub (the "!" prefix runs the command inside a Jupyter notebook)
!pip install git+https://github.com/KindXiaoming/pykan.git

import torch
import numpy as np

# Create a dataset
def create_dataset(f, n_var=2, n_samples=1000, split_ratio=0.8):
    # Generate random input data, uniform in [0, 1)
    X = torch.rand(n_samples, n_var)
    # Compute the target values
    y = f(X)
    # Split into training and test sets
    split_idx = int(n_samples * split_ratio)
    train_input, test_input = X[:split_idx], X[split_idx:]
    train_label, test_label = y[:split_idx], y[split_idx:]
    return {
        'train_input': train_input,
        'train_label': train_label,
        'test_input': test_input,
        'test_label': test_label,
    }

# Define the target function f(x, y) = exp(cos(pi*x) + y^2)
f = lambda x: torch.exp(torch.cos(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
dataset = create_dataset(f, n_var=2)
print(dataset['train_input'].shape, dataset['train_label'].shape)
# output: torch.Size([800, 2]) torch.Size([800, 1])

from kan import KAN

# Create a KAN: 2D inputs, 1D output, and 5 hidden neurons;
# cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)

# Plot the KAN at initialization (the forward pass records the activations to plot)
model(dataset['train_input'])
model.plot(beta=100)

# Train the model with LBFGS and sparsity regularization.
# Note: this uses the early pykan API; newer pykan releases renamed train() to fit().
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.)
# output: train loss: 7.23e-02 | test loss: 8.59e-02 | reg: 3.16e+01 : 100%|██| 20/20 [00:11<00:00, 1.69it/s]
model.plot()

# Prune, then plot with the pruned neurons masked
model.prune()
model.plot(mask=True)

# Keep the pruned model, run a forward pass, and plot it
model = model.prune()
model(dataset['train_input'])
model.plot()

# Retrain the pruned model for 100 more steps
model.train(dataset, opt="LBFGS", steps=100)
model.plot()
```
Code Explanation
Install the pykan library from GitHub.
Import libraries.
The create_dataset function generates random input data (X) and computes the target values (y) using the function f. It splits the data into training and test sets according to split_ratio and returns a dictionary containing the training and test inputs and labels. Its parameters are:
f: function that generates the target values.
n_var: number of input variables.
n_samples: total number of samples.
split_ratio: fraction of the samples used for training; the remainder forms the test set.
Define the target function f(x, y) = exp(cos(pi*x) + y^2).
Call the function create_dataset to create a dataset using the previously defined function f with 2 input variables.
Print the shape of training inputs and their labels.
Initialize a KAN with 2-dimensional inputs, a 1-dimensional output, 5 hidden neurons, cubic splines (k=3), and 5 grid intervals (grid=5).
Plot the KAN model at initialization.
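Train the KAN on the dataset for 20 steps with the LBFGS optimizer. The regularization strengths lamb=0.01 and lamb_entropy=10. encourage a sparse network, which makes the later pruning step effective.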
After training, plot the trained model.
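Prune the model and plot it with mask=True, which shows the network with the pruned neurons masked out.
Assign the pruned model returned by prune() to model, run a forward pass on the training inputs so that the activations needed for plotting are recorded, and plot the pruned model.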
Re-train the pruned model for an additional 100 steps.
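The grid and k arguments control each edge's spline. With 5 grid intervals and cubic splines, every edge function is a weighted sum of grid + k = 8 B-spline basis functions, so each edge has 8 learnable coefficients. The NumPy sketch below evaluates such a basis with the Cox-de Boor recursion. It assumes a uniform grid on [-1, 1] extended by k knots on each side, and the function name bspline_basis is hypothetical. It illustrates the idea and is not pykan's internal code.

```python
import numpy as np

def bspline_basis(x, grid_intervals=5, k=3, x_min=-1.0, x_max=1.0):
    """Evaluate all degree-k B-spline basis functions at points x (Cox-de Boor recursion)."""
    h = (x_max - x_min) / grid_intervals
    # Uniform knots on [x_min, x_max], extended by k extra knots on each side
    knots = x_min + h * np.arange(-k, grid_intervals + k + 1)
    x = np.asarray(x, dtype=float)[:, None]
    # Degree 0: indicator of each knot interval [t_i, t_{i+1})
    B = ((x >= knots[:-1]) & (x < knots[1:])).astype(float)
    for d in range(1, k + 1):
        # B_{i,d} = (x - t_i)/(t_{i+d} - t_i) B_{i,d-1} + (t_{i+d+1} - x)/(t_{i+d+1} - t_{i+1}) B_{i+1,d-1}
        left = (x - knots[:-(d + 1)]) / (knots[d:-1] - knots[:-(d + 1)]) * B[:, :-1]
        right = (knots[d + 1:] - x) / (knots[d + 1:] - knots[1:-d]) * B[:, 1:]
        B = left + right
    return B  # shape: (len(x), grid_intervals + k)

# One edge function: 8 = grid + k learnable coefficients (random here, for illustration)
coef = np.random.default_rng(0).normal(size=8)
xs = np.linspace(-1.0, 1.0, 5)
phi = bspline_basis(xs) @ coef
print(bspline_basis(xs).shape)  # (5, 8)
```

Inside the grid range, the basis functions sum to 1 at every point. Training therefore adjusts only the coefficients, and the shape of each edge function follows from them.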
MLP vs KAN
| MLP | KAN |
| --- | --- |
| Fixed node activation functions | Learnable activation functions |
| Linear weights | Parametrized splines |
| Less interpretable | More interpretable |
| Less flexible and adaptable | Highly flexible and adaptable |
| Faster training time | Slower training time |
| Based on the Universal Approximation Theorem | Based on the Kolmogorov-Arnold Representation Theorem |
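The difference in training time follows partly from parameter counts. The rough sketch below counts only the spline coefficients on the KAN side. An MLP needs one weight per edge plus one bias per node, while a KAN needs grid + k coefficients per edge. The function names are hypothetical, and implementations such as pykan also carry a few extra per-edge parameters that are not counted here.

```python
def mlp_param_count(widths):
    # One weight per edge plus one bias per non-input node
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths[:-1], widths[1:]))

def kan_spline_coef_count(widths, grid=5, k=3):
    # One learnable spline per edge, each with (grid + k) B-spline coefficients
    n_edges = sum(n_in * n_out for n_in, n_out in zip(widths[:-1], widths[1:]))
    return n_edges * (grid + k)

widths = [2, 5, 1]
print(mlp_param_count(widths))        # 21
print(kan_spline_coef_count(widths))  # 120
```

For the [2, 5, 1] network used in the implementation above, that is 21 MLP parameters versus 120 KAN spline coefficients. The KAN must also evaluate a spline on every edge.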
Conclusion
KANs mark a step forward in deep learning techniques. Because they can offer better interpretability and accuracy than MLPs, they can be the better choice when interpretability and accuracy of the results are the main objectives. However, MLPs remain the more practical option when training speed is essential. Research on improving these networks is ongoing, but for now, KANs represent an exciting alternative to MLPs.
Key Takeaways
KANs are a new type of neural network with learnable activation functions on edges based on the Kolmogorov-Arnold representation theorem.
KANs provide greater flexibility and adaptability, better handling of complex data, superior accuracy, and higher interpretability than MLPs.
The blog details how to implement KANs in Python, including dataset creation, model initialization, training, and visualization.
KANs differ from MLPs by having learnable activation functions and parametrized splines, making them more interpretable but slower to train.
KANs represent an advanced alternative to MLPs, particularly when accuracy and interpretability are prioritized over training speed.
The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.
Frequently Asked Questions
Q. Who developed KANs?
A. Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljačić, Thomas Y. Hou, and Max Tegmark are the researchers involved in the development of KANs.
Q. What is the difference between fixed and learnable activation functions?
A. Fixed activation functions are predefined mathematical functions applied to each neuron's weighted input sum. They remain unchanged throughout training and have no parameters that the network learns. Examples include sigmoid, tanh, and ReLU.
Learnable activation functions adapt during training. Instead of being predefined, their parameters are updated through backpropagation, allowing the network to learn the most suitable activation functions.
Q. What are the limitations of KANs?
A. One limitation of KANs is their slower training time. Because they replace linear weights with spline-based functions, each edge requires more computation to evaluate and optimize during training.
Q. When should I use KANs instead of MLPs?
A. If your task prioritizes accuracy and interpretability and training time is not a constraint, KANs are a good choice. If training time is critical, MLPs are the more practical option.
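Q. What is the LBFGS optimizer?
A. LBFGS stands for “Limited-memory Broyden–Fletcher–Goldfarb–Shanno.” It is a quasi-Newton algorithm widely used for parameter estimation in machine learning and numerical optimization. Instead of computing the full Hessian, it approximates curvature from a short history of recent parameter and gradient changes, which keeps its memory use low.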
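To make the idea concrete, here is a minimal NumPy sketch of L-BFGS. It keeps only the last m pairs of parameter and gradient differences and combines them with the standard two-loop recursion to approximate a Newton step. This is an illustrative implementation with a simple backtracking line search, not the optimizer pykan uses. The function name lbfgs and its parameters are hypothetical.

```python
import numpy as np

def lbfgs(f, grad, x0, m=5, max_iter=200, tol=1e-8):
    """Minimize f with L-BFGS (two-loop recursion + backtracking line search)."""
    x = np.asarray(x0, dtype=float).copy()
    g = grad(x)
    s_hist, y_hist = [], []  # last m pairs: s = x_new - x, y = g_new - g
    for _ in range(max_iter):
        if np.linalg.norm(g) < tol:
            break
        # First loop: newest pair to oldest
        q = g.copy()
        saved = []
        for s, y in zip(reversed(s_hist), reversed(y_hist)):
            rho = 1.0 / (y @ s)
            alpha = rho * (s @ q)
            q -= alpha * y
            saved.append((rho, alpha))
        # Initial inverse-Hessian guess: gamma * I with gamma = s^T y / y^T y
        if s_hist:
            q *= (s_hist[-1] @ y_hist[-1]) / (y_hist[-1] @ y_hist[-1])
        # Second loop: oldest pair to newest
        for s, y, (rho, alpha) in zip(s_hist, y_hist, reversed(saved)):
            beta = rho * (y @ q)
            q += (alpha - beta) * s
        d = -q  # quasi-Newton search direction
        # Backtracking line search on the Armijo (sufficient decrease) condition
        t, fx, slope = 1.0, f(x), g @ d
        while t > 1e-12 and f(x + t * d) > fx + 1e-4 * t * slope:
            t *= 0.5
        x_new = x + t * d
        g_new = grad(x_new)
        s, y = x_new - x, g_new - g
        if s @ y > 1e-12:  # keep only pairs with positive curvature
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > m:
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
    return x

# Usage: minimize 0.5 x^T A x - b^T x, whose minimizer solves A x = b
A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, 1.0])
x_min = lbfgs(lambda x: 0.5 * x @ A @ x - b @ x, lambda x: A @ x - b, np.zeros(2))
print(x_min)  # close to [0.2, 0.4]
```

Storing only m vector pairs, rather than a full n × n matrix, is what makes the method "limited-memory."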