决策树算法实现解析:从零开始的数据科学项目
2025-07-06 07:47:52作者:柯茵沙
决策树是机器学习中最基础也是最直观的算法之一,它通过一系列规则对数据进行分类或预测。本文将通过分析一个决策树实现代码,深入讲解其核心原理和实现细节。
决策树基础概念
决策树是一种树形结构,其中每个内部节点表示一个属性测试,每个分支代表测试结果,而每个叶节点代表最终的分类结果。决策树学习的目标是构建一个能够对新数据进行准确分类的树模型。
核心函数解析
1. 信息熵计算
信息熵是决策树算法中用于衡量数据集不确定性的重要指标:
def entropy(class_probabilities):
"""给定类别概率列表,计算信息熵"""
return sum(-p * math.log(p, 2) for p in class_probabilities if p)
这个函数实现了香农熵的计算,使用以2为底的对数,结果单位为比特。熵值越大表示数据越混乱,不确定性越高。
2. 数据分区与条件熵
决策树通过选择能够最大程度降低熵的属性进行分裂:
def partition_entropy_by(inputs, attribute):
"""计算按给定属性分区后的条件熵"""
partitions = partition_by(inputs, attribute)
return partition_entropy(partitions.values())
partition_entropy_by
函数计算按某个属性划分后的条件熵,帮助我们评估该属性作为分裂点的优劣。
ID3算法实现
ID3算法是经典的决策树学习算法,使用信息增益作为属性选择标准:
def build_tree_id3(inputs, split_candidates=None):
# 基础情况处理
if split_candidates is None:
split_candidates = inputs[0][0].keys()
# 计算当前节点的正负样本数
num_trues = len([label for item, label in inputs if label])
num_falses = len(inputs) - num_trues
# 终止条件1: 所有样本属于同一类别
if num_trues == 0: return False
if num_falses == 0: return True
# 终止条件2: 没有更多分裂属性
if not split_candidates:
return num_trues >= num_falses
# 选择最佳分裂属性
best_attribute = min(split_candidates,
key=partial(partition_entropy_by, inputs))
# 递归构建子树
partitions = partition_by(inputs, best_attribute)
new_candidates = [a for a in split_candidates if a != best_attribute]
subtrees = {attribute: build_tree_id3(subset, new_candidates)
for attribute, subset in partitions.items()}
# 处理缺失值情况
subtrees[None] = num_trues > num_falses
return (best_attribute, subtrees)
该实现包含了ID3算法的所有关键步骤:
- 计算当前节点的类别分布
- 处理终止条件
- 选择最佳分裂属性
- 递归构建子树
- 处理缺失值情况
决策树分类
构建好的决策树可以用于新数据的分类:
def classify(tree, input):
"""使用决策树对输入进行分类"""
if tree in [True, False]: # 叶节点直接返回值
return tree
attribute, subtree_dict = tree
subtree_key = input.get(attribute) # 获取对应属性的值
if subtree_key not in subtree_dict: # 处理未知值
subtree_key = None
subtree = subtree_dict[subtree_key]
return classify(subtree, input) # 递归分类
随机森林扩展
虽然基础实现只包含单个决策树,但代码也提供了简单的随机森林支持:
def forest_classify(trees, input):
votes = [classify(tree, input) for tree in trees]
vote_counts = Counter(votes)
return vote_counts.most_common(1)[0][0]
这个函数实现了最简单的投票机制,多个决策树对同一个输入进行分类,最终选择得票最多的类别作为结果。
实际应用示例
代码中包含了一个完整的示例,使用员工数据预测是否适合聘用:
inputs = [
({'level':'Senior','lang':'Java','tweets':'no','phd':'no'}, False),
# 更多样本数据...
]
# 构建决策树
tree = build_tree_id3(inputs)
# 对新数据进行分类
print(classify(tree, {"level": "Junior", "lang": "Java",
"tweets": "yes", "phd": "no"})) # 输出: True
算法特点与局限性
这个实现展示了决策树算法的几个关键特点:
- 直观易懂,决策过程可以可视化
- 能够处理分类问题
- 实现相对简单
但也存在一些局限性:
- 没有实现剪枝,容易过拟合
- 只支持离散属性
- 没有处理连续值
- 信息增益可能偏向选择取值较多的属性
总结
通过这个决策树实现,我们深入理解了ID3算法的核心思想和实现细节。决策树作为基础算法,虽然简单但在许多场景下仍然非常有效。理解这个基础实现有助于我们更好地掌握更复杂的树模型,如C4.5、CART以及随机森林等算法。