wiseodd/generative-models中的RBM实现解析:基于对比散度的二值受限玻尔兹曼机
本文深入分析wiseodd/generative-models项目中关于受限玻尔兹曼机(RBM)的实现,重点讲解其核心算法原理和代码实现细节。我们将从RBM的基本概念出发,逐步解析这个基于对比散度(Contrastive Divergence)的二值RBM实现。
一、受限玻尔兹曼机基础
受限玻尔兹曼机(Restricted Boltzmann Machine, RBM)是一种两层的概率生成模型,由可见层(visible layer)和隐藏层(hidden layer)组成。与普通玻尔兹曼机不同,RBM的层内节点之间没有连接,这使得其训练更加高效。
RBM的核心特点:
- 无向图模型
- 两层结构:可见层v和隐藏层h
- 层内无连接,只有层间全连接
- 能量函数定义:E(v,h) = -aᵀv - bᵀh - vᵀWh
二、代码实现解析
1. 数据准备与初始化
实现中使用了MNIST手写数字数据集,将像素值二值化处理(>0.5为1,否则为0):
mnist = input_data.read_data_sets('../MNIST_data', one_hot=True)
X_dim = mnist.train.images.shape[1] # 输入维度(784)
y_dim = mnist.train.labels.shape[1] # 标签维度(10)
mb_size = 16 # 小批量大小
h_dim = 36 # 隐藏层维度
# 初始化参数
W = np.random.randn(X_dim, h_dim) * 0.001 # 权重矩阵
a = np.random.randn(h_dim) * 0.001 # 隐藏层偏置
b = np.random.randn(X_dim) * 0.001 # 可见层偏置
2. 核心函数实现
sigmoid激活函数:用于概率计算
def sigm(x):
return 1/(1 + np.exp(-x))
推断函数(infer):计算给定可见层时隐藏层的激活概率
def infer(X):
return sigm(X @ W) # X: mb_size x x_dim → mb_size x h_dim
生成函数(generate):计算给定隐藏层时可见层的激活概率
def generate(H):
return sigm(H @ W.T) # H: mb_size x h_dim → mb_size x x_dim
3. 对比散度(CD)训练
对比散度是训练RBM的核心算法,相比传统的吉布斯采样更高效:
alpha = 0.1 # 学习率
K = 10 # 吉布斯采样步数
for t in range(1, 1001):
# 获取小批量数据并二值化
X_mb = (mnist.train.next_batch(mb_size)[0] > 0.5).astype(np.float)
# 初始化梯度
g, g_a, g_b = 0, 0, 0
for v in X_mb:
# 正向传播计算隐藏层概率
h = infer(v)
# 吉布斯采样过程
v_prime = np.copy(v)
for k in range(K):
h_prime = np.random.binomial(n=1, p=infer(v_prime))
v_prime = np.random.binomial(n=1, p=generate(h_prime))
# 计算重构后的隐藏层概率
h_prime = infer(v_prime)
# 计算梯度
grad_w = np.outer(v, h) - np.outer(v_prime, h_prime)
grad_a = h - h_prime
grad_b = v - v_prime
# 累积梯度
g += grad_w
g_a += grad_a
g_b += grad_b
# 计算平均梯度并更新参数
W += alpha * (g / mb_size)
a += alpha * (g_a / mb_size)
b += alpha * (g_b / mb_size)
4. 结果可视化
训练完成后,代码提供了可视化功能来展示RBM的学习效果:
# 测试数据推断
X = (mnist.test.next_batch(mb_size)[0] > 0.5).astype(np.float)
H = np.random.binomial(n=1, p=infer(X)) # 采样隐藏层状态
plot(H, np.sqrt(h_dim), 'H') # 绘制隐藏层表示
# 重构测试
X_recon = (generate(H) > 0.5).astype(np.float)
plot(X_recon, np.sqrt(X_dim), 'V') # 绘制重构结果
三、技术要点解析
-
对比散度算法:相比传统吉布斯采样需要大量迭代才能收敛,CD算法仅需少量步骤(K=10)就能得到较好的梯度估计,大大提高了训练效率。
-
二值化处理:MNIST原始数据是0-1之间的灰度值,代码中将其二值化为0或1,符合二值RBM的输入要求。
-
参数更新:采用简单的梯度上升法,没有使用动量等优化技巧,保持了算法的简洁性。
-
采样方法:使用np.random.binomial进行随机采样,模拟了RBM的随机性特征。
四、实际应用建议
-
参数调整:可以尝试不同的隐藏层维度(h_dim)和学习率(alpha)来优化模型性能。
-
改进训练:可以考虑加入动量项或使用更先进的优化器如Adam来加速收敛。
-
扩展应用:此RBM可作为深度信念网络(DBN)的构建模块,或用于特征提取等任务。
-
评估指标:可以添加重构误差计算等评估指标来监控训练过程。
通过这个实现,我们能够清晰地理解RBM的工作原理和训练过程。虽然现代深度学习已经发展出更复杂的生成模型,但RBM作为经典的概率图模型,其思想和原理仍然具有重要的学习和研究价值。