1. 引言
在过去十余年中,多层感知机(MLP)作为深度学习的基础结构,被广泛应用于图像识别、自然语言处理、科学计算等领域。然而,随着任务复杂度和模型规模的不断增长,MLP 在参数效率、可解释性、泛化能力等方面的局限逐渐显现。传统 MLP 固定在节点上的非线性激活函数(如 ReLU、Sigmoid、Tanh)在表达能力和适应性上存在一定瓶颈,尤其是在需要高度结构化表示或符号化理解的任务中,难以兼顾性能与可解释性。
2024 年,麻省理工学院(MIT)团队提出了一种基于 Kolmogorov-Arnold 表示定理的新型神经网络架构——Kolmogorov-Arnold Networks(KAN)。该架构将非线性映射从节点移动到边上,并通过可学习的单变量函数(如样条函数)进行组合,从而在理论上保证任意连续多变量函数的精确表示。KAN 的设计理念不仅保留了深度神经网络的通用逼近能力,还显著提升了模型在低样本、符号化建模以及科学计算场景中的表现潜力。
本文旨在在标准视觉任务中对 KAN 的可行性进行验证。具体而言,我们将复现 MIT 团队在论文中提出的核心思想与实现细节,并在 MNIST 与 CIFAR-10 数据集上进行对比实验,评估其在准确率、训练效率、参数规模和可解释性方面相对于 MLP 的优势与不足。通过这一过程,我们希望探讨 KAN 在 2025 年是否具备替代 MLP 的现实可能性。
2. 理论基础与架构机制
Kolmogorov-Arnold 表示定理是 20 世纪数学分析的重要成果之一,其核心观点是:任意连续的多元函数,都可以通过有限个单变量函数的加和与复合来表示。KAN(Kolmogorov-Arnold Networks)正是基于这一思想,将传统神经网络中位于节点上的固定激活函数,迁移至连接边上,并让这些非线性映射变为可学习的单变量函数。
在 KAN 中,每一条连接边都不再是单一的权重,而是一个可学习的函数映射(常用样条函数实现)。节点之间的计算不再是简单的线性加权,而是多个单变量函数的组合与累加。这种架构的直接好处是:
表达能力提升:边上可学习的函数能自适应数据分布,减少对大规模参数的依赖。
参数效率更高:在相同精度下,所需参数量可比 MLP 少。
可解释性增强:每条边的函数都可以单独分析、可视化,便于理解模型如何映射输入到输出。
在实现上,KAN 通常会使用样条插值(Spline)来表示可学习函数,因为样条在数值稳定性、光滑性和可微性上都表现出色,同时梯度计算也相对容易。MIT 的开源实现提供了 pykan 库,可以直接在 PyTorch 中定义和训练 KAN 模型。
下面给出一个使用 pykan 库实现简单 KAN 架构的示例代码(以 MNIST 输入维度为例):
# 安装 pykan 库(如果尚未安装)
# pip install pykan torch torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from pykan import KAN
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root=\’./data\’, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=\’./data\’, train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
# 定义 KAN 模型
# 这里输入为 28*28 = 784 维,输出为 10 类
# hidden_sizes 定义每层的隐藏单元数
model = KAN(
layers_hidden=[784, 64, 32, 10], # 网络层结构
spline_order=3, # 样条阶数
grid_size=16 # 样条网格密度
)
# 定义损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练函数
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.view(data.size(0), -1).to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f\”Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}\”)
# 测试函数
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.view(data.size(0), -1).to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
test_loss /= len(test_loader.dataset)
print(f\”\\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} \”
f\”({100. * correct / len(test_loader.dataset):.2f}%)\\n\”)
# 设备选择
device = torch.device(\”cuda\” if torch.cuda.is_available() else \”cpu\”)
model.to(device)
# 训练与测试循环
for epoch in range(1, 6):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
该示例展示了如何在 MNIST 数据集上快速搭建并训练一个 KAN 模型。与 MLP 不同的是,这里的每一条连接边都使用了可学习的样条函数来进行映射,这使得网络在学习过程中能够灵活地适应数据的非线性特征。
3. 实验设计与复现方案
为了系统评估 KAN 在标准视觉任务上的表现,我们设计了对比实验,目标是验证其在准确率、训练效率、参数规模与可解释性方面的优势与不足。本实验采用 MNIST 和 CIFAR-10 两个经典数据集,并以 MLP 作为基准模型。实验流程分为数据准备、模型定义、训练流程与可解释性可视化四个部分。
实验目标
-
在 MNIST 与 CIFAR-10 数据集上复现
评论前必须登录!
注册