TabNet预训练实战教程:从数据准备到模型解释
2025-07-10 06:25:33作者:齐冠琰
前言
TabNet是一种基于深度学习的表格数据建模框架,它结合了神经网络的优势和决策树的可解释性。本教程将详细介绍如何使用TabNet进行预训练和微调,特别关注自监督预训练过程,这是TabNet区别于传统表格数据建模方法的重要特性。
环境准备
在开始之前,我们需要确保环境配置正确:
%load_ext autoreload
%autoreload 2
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score
import pandas as pd
import numpy as np
np.random.seed(0)
import os
from matplotlib import pyplot as plt
%matplotlib inline
数据准备
我们使用UCI机器学习库中的"Adult"数据集(也称为Census Income数据集),这是一个经典的分类数据集,任务是预测个人收入是否超过50K美元。
数据加载与分割
train = pd.read_csv('data/census-income.csv')
target = ' <=50K'
# 随机划分训练集、验证集和测试集
if "Set" not in train.columns:
train["Set"] = np.random.choice(["train", "valid", "test"], p=[.8, .1, .1], size=(train.shape[0],))
train_indices = train[train.Set=="train"].index
valid_indices = train[train.Set=="valid"].index
test_indices = train[train.Set=="test"].index
数据预处理
TabNet能够直接处理混合类型的特征(数值型和类别型),但仍需要进行基本的预处理:
nunique = train.nunique()
types = train.dtypes
categorical_columns = []
categorical_dims = {}
for col in train.columns:
if types[col] == 'object' or nunique[col] < 200:
print(col, train[col].nunique())
l_enc = LabelEncoder()
train[col] = train[col].fillna("VV_likely") # 处理缺失值
train[col] = l_enc.fit_transform(train[col].values)
categorical_columns.append(col)
categorical_dims[col] = len(l_enc.classes_)
else:
train.fillna(train.loc[train_indices, col].mean(), inplace=True) # 数值型特征用均值填充
TabNet预训练
TabNet的预训练采用自监督学习方式,通过掩码部分输入特征并尝试重建来学习数据的表示。
预训练模型配置
from pytorch_tabnet.pretraining import TabNetPretrainer
unsupervised_model = TabNetPretrainer(
cat_idxs=cat_idxs, # 类别型特征的索引
cat_dims=cat_dims, # 每个类别型特征的维度
cat_emb_dim=3, # 类别型特征的嵌入维度
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
mask_type='entmax', # 掩码类型,可选"sparsemax"
n_shared_decoder=1, # 解码器中共享的GLU层数
n_indep_decoder=1, # 解码器中独立的GLU层数
verbose=5,
)
执行预训练
max_epochs = 100 # 可根据实际情况调整
unsupervised_model.fit(
X_train=X_train,
eval_set=[X_valid], # 验证集用于早停
max_epochs=max_epochs,
patience=5, # 早停耐心值
batch_size=2048,
virtual_batch_size=128, # 虚拟批次大小,用于内存优化
num_workers=0,
drop_last=False,
pretraining_ratio=0.5, # 掩码比例
)
预训练结果分析
预训练完成后,我们可以检查重建效果和特征重要性:
# 重建验证集数据
reconstructed_X, embedded_X = unsupervised_model.predict(X_valid)
# 解释模型决策
unsupervised_explain_matrix, unsupervised_masks = unsupervised_model.explain(X_valid)
# 可视化掩码
fig, axs = plt.subplots(1, 3, figsize=(20,20))
for i in range(3):
axs[i].imshow(unsupervised_masks[i][:50])
axs[i].set_title(f"mask {i}")
监督学习微调
预训练完成后,我们可以使用学到的表示来初始化监督学习模型。
模型初始化与训练
clf = TabNetClassifier(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-3),
scheduler_params={"step_size":10, "gamma":0.9}, # 学习率调度器参数
scheduler_fn=torch.optim.lr_scheduler.StepLR,
mask_type='sparsemax',
verbose=5,
)
clf.fit(
X_train=X_train, y_train=y_train,
eval_set=[(X_train, y_train), (X_valid, y_valid)],
eval_name=['train', 'valid'],
eval_metric=['auc'],
max_epochs=max_epochs,
patience=20,
batch_size=1024,
virtual_batch_size=128,
from_unsupervised=loaded_pretrain, # 使用预训练模型初始化
)
训练过程监控
# 绘制损失曲线
plt.plot(clf.history['loss'])
# 绘制AUC曲线
plt.plot(clf.history['train_auc'])
plt.plot(clf.history['valid_auc'])
# 绘制学习率变化
plt.plot(clf.history['lr'])
模型评估与解释
性能评估
preds = clf.predict_proba(X_test)
test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)
print(f"BEST VALID SCORE: {clf.best_cost}")
print(f"FINAL TEST SCORE: {test_auc}")
特征重要性分析
TabNet提供了全局和局部特征重要性分析:
# 全局特征重要性
clf.feature_importances_
# 局部解释性
explain_matrix, masks = clf.explain(X_test)
# 可视化掩码
fig, axs = plt.subplots(1, 3, figsize=(20,20))
for i in range(3):
axs[i].imshow(masks[i][:50])
axs[i].set_title(f"mask {i}")
模型保存与加载
# 保存模型
saving_path_name = "./tabnet_model_test_1"
saved_filepath = clf.save_model(saving_path_name)
# 加载模型
loaded_clf = TabNetClassifier()
loaded_clf.load_model(saved_filepath)
# 验证加载的模型
loaded_preds = loaded_clf.predict_proba(X_test)
loaded_test_auc = roc_auc_score(y_score=loaded_preds[:,1], y_true=y_test)
print(f"LOADED MODEL TEST SCORE: {loaded_test_auc}")
总结
本教程详细介绍了TabNet的预训练和微调流程,展示了如何:
- 准备和预处理表格数据
- 配置并执行自监督预训练
- 使用预训练模型初始化监督学习任务
- 评估模型性能并解释模型决策
TabNet的预训练特性使其在数据稀缺的情况下尤其有价值,通过自监督学习可以利用大量未标注数据学习有用的表示。同时,TabNet提供的特征重要性分析和决策解释能力使其在实际业务场景中更具可信度。