云计算百科
云计算领域专业知识百科平台

机器学习分类算法一——决策树

一、算法原理

决策树算法通过将数据集划分为不同的子集来预测目标变量。它从根节点开始,根据某个特征对数据集进行划分,然后递归地生成更多的子节点,直到满足停止条件为止。决策树的每个内部节点表示一个特征属性上的判断条件,每个分支代表一个可能的属性值,每个叶节点表示一个分类结果。

二、参考代码

import numpy as np
import matplotlib.pyplot as plt

class TreeNode:
    """A single node of the decision tree.

    Internal (split) nodes carry ``feature_idx``, ``threshold``, ``left`` and
    ``right``; leaf nodes carry only ``class_dist`` and have
    ``feature_idx is None``.
    """

    def __init__(self, feature_idx=None, threshold=None, left=None, right=None, class_dist=None):
        self.feature_idx = feature_idx  # index of the splitting feature (None for leaves)
        self.threshold = threshold      # split point: go left when value <= threshold
        self.left = left                # subtree for samples with value <= threshold
        self.right = right              # subtree for samples with value > threshold
        self.class_dist = class_dist    # per-class sample counts at a leaf (used for prediction)


class SimpleDecisionTree:
    """A minimal CART-style binary decision-tree classifier using Gini impurity.

    NOTE: class labels are assumed to be integers in {0, 1}
    (``np.bincount(..., minlength=2)`` hard-codes the binary case).
    """

    def __init__(self, max_depth=3):
        self.max_depth = max_depth  # maximum depth before forcing a leaf
        self.tree = None            # root TreeNode, set by fit()
        self.n_features = None      # number of feature columns, set by fit()

    def fit(self, X, y):
        """Grow the tree from X of shape (n_samples, n_features) and labels y."""
        y = np.asarray(y).astype(int)  # np.bincount requires integer labels
        self.n_features = X.shape[1]
        self.tree = self._grow_tree(X, y, depth=0)

    def _grow_tree(self, X, y, depth):
        """Recursively build and return the (sub)tree for samples (X, y)."""
        n_classes = len(np.unique(y))

        # Stop: maximum depth reached, or the node is already pure.
        if depth >= self.max_depth or n_classes == 1:
            return TreeNode(class_dist=np.bincount(y, minlength=2))

        best_feature, best_threshold = self._best_split(X, y)

        # No usable split exists (e.g. every feature value is identical):
        # fall back to a leaf instead of indexing with None and crashing.
        if best_feature is None:
            return TreeNode(class_dist=np.bincount(y, minlength=2))

        left_idx = X[:, best_feature] <= best_threshold
        right_idx = ~left_idx
        left = self._grow_tree(X[left_idx], y[left_idx], depth + 1)
        right = self._grow_tree(X[right_idx], y[right_idx], depth + 1)
        return TreeNode(best_feature, best_threshold, left, right)

    def _best_split(self, X, y):
        """Exhaustive search for the (feature, threshold) pair with the lowest
        weighted Gini impurity; returns (None, None) if no valid split exists."""
        best_gini = float('inf')
        best_feature, best_threshold = None, None

        for feature_idx in range(self.n_features):
            for threshold in np.unique(X[:, feature_idx]):
                gini = self._gini_impurity(X[:, feature_idx], y, threshold)
                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature_idx
                    best_threshold = threshold

        return best_feature, best_threshold

    def _gini_impurity(self, feature_col, y, threshold):
        """Weighted Gini impurity of splitting ``feature_col`` at ``<= threshold``.

        Returns ``inf`` when one side of the split would be empty, so such
        splits are never selected.
        """
        left_idx = feature_col <= threshold
        right_idx = ~left_idx

        if np.sum(left_idx) == 0 or np.sum(right_idx) == 0:
            return float('inf')

        left_dist = np.bincount(y[left_idx], minlength=2)
        right_dist = np.bincount(y[right_idx], minlength=2)

        left_prob = left_dist / np.sum(left_dist)
        right_prob = right_dist / np.sum(right_dist)

        # Gini = 1 - sum(p_k^2); the published listing lost the minus sign.
        gini_left = 1 - np.sum(left_prob ** 2)
        gini_right = 1 - np.sum(right_prob ** 2)

        # Average the two child impurities weighted by their sample counts.
        return (np.sum(left_idx) * gini_left + np.sum(right_idx) * gini_right) / len(y)

    def predict(self, X):
        """Return the predicted class label for each row of X."""
        return np.array([self._predict_single(x, self.tree) for x in X])

    def _predict_single(self, x, node):
        """Walk the tree for one sample and return the leaf's majority class."""
        if node.feature_idx is None:  # leaf node
            return np.argmax(node.class_dist)
        if x[node.feature_idx] <= node.threshold:
            return self._predict_single(x, node.left)
        return self._predict_single(x, node.right)

    def visualize_tree(self, feature_names=None, class_names=None):
        """Draw the fitted tree with matplotlib (call only after fit())."""
        if feature_names is None:
            feature_names = [f"Feature {i}" for i in range(self.n_features)]
        if class_names is None:
            class_names = ["no", "yes"]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.axis("off")

        def calc_depth(node):
            # Height of the (sub)tree; 0 for an empty node.
            if node is None:
                return 0
            return max(calc_depth(node.left), calc_depth(node.right)) + 1

        tree_depth = calc_depth(self.tree)

        def _draw_node(node, x, y, x_offset, y_step, depth):
            # Recursively place nodes; children sit one y_step lower and
            # +/- x_offset to the side, with the spread halved per level.
            if node is None:
                return

            if node.feature_idx is not None:  # internal split node
                ax.text(
                    x,
                    y,
                    f"{feature_names[node.feature_idx]} <= {node.threshold:.1f}",
                    ha="center",
                    va="center",
                    bbox=dict(facecolor="white", edgecolor="black", boxstyle="round"),
                )
            else:  # leaf: predicted class, sample count and class probability
                class_idx = np.argmax(node.class_dist)
                class_label = class_names[class_idx]
                n_samples = np.sum(node.class_dist)
                prob = node.class_dist[class_idx] / n_samples
                ax.text(
                    x,
                    y,
                    f"Predict: {class_label}\n(n={n_samples}, p={prob:.2f})",
                    ha="center",
                    va="center",
                    bbox=dict(
                        facecolor="lightgreen" if class_idx == 1 else "lightcoral",
                        edgecolor="black",
                        boxstyle="round",
                    ),
                )

            if node.left or node.right:
                _draw_node(node.left, x - x_offset, y - y_step, x_offset / 2, y_step, depth + 1)
                _draw_node(node.right, x + x_offset, y - y_step, x_offset / 2, y_step, depth + 1)

                # Arrows from this node to its two children.
                ax.annotate(
                    "",
                    xy=(x - x_offset, y - y_step + 0.1),
                    xytext=(x, y - 0.1),
                    arrowprops=dict(arrowstyle="->", color="blue"),
                )
                ax.annotate(
                    "",
                    xy=(x + x_offset, y - y_step + 0.1),
                    xytext=(x, y - 0.1),
                    arrowprops=dict(arrowstyle="->", color="red"),
                )

        # Spacing scales with tree depth so the layout stays compact.
        initial_x_offset = 0.5 / tree_depth
        y_step = 1 / tree_depth
        _draw_node(self.tree, x=0.5, y=0.9, x_offset=initial_x_offset, y_step=y_step, depth=0)

        plt.tight_layout()
        plt.show()

三、代码分析

1、决策树生成逻辑

在这里插入图片描述

2、决策树预测逻辑

在这里插入图片描述

3、最佳分割线

def _best_split(self, X, y):
best_gini = float('inf')
best_feature, best_threshold = None, None

for feature_idx in range(self.n_features):
thresholds = np.unique(X[:, feature_idx])
for threshold in thresholds:
gini = self._gini_impurity(X[:, feature_idx], y, threshold)
if gini < best_gini:
best_gini = gini
best_feature = feature_idx
best_threshold = threshold

return best_feature, best_threshold

def _gini_impurity(self, feature_col, y, threshold):
left_idx = feature_col <= threshold
right_idx = ~left_idx

if np.sum(left_idx) == 0 or np.sum(right_idx) == 0:
return float('inf')

left_dist = np.bincount(y[left_idx], minlength=2)
right_dist = np.bincount(y[right_idx], minlength=2)

left_prob = left_dist / np.sum(left_dist)
right_prob = right_dist / np.sum(right_dist)

gini_left = 1 np.sum(left_prob ** 2)
gini_right = 1 np.sum(right_prob ** 2)

total_gini = (np.sum(left_idx) * gini_left + np.sum(right_idx) * gini_right) / len(y)
return total_gini

(1)Gini不纯度

定义:Gini 不纯度衡量的是一个数据集中 随机选取两个样本,其类别不一致的概率。 公式:

$$Gini = 1 - \sum_{k=1}^{K} p_k^2$$

其中:K 是类别总数(如二分类问题中 K=2)。

$p_k$ 是第 k 类样本在数据集中的比例。

计算步骤:

1)选择一个分割特征和阈值

如age ≤ 40

2)划分数据:

左子集:满足条件的数据(age ≤ 40)。 右子集:不满足条件的数据(age > 40)。

3)计算左右子集的 Gini 不纯度:

对左子集和右子集分别计算 Gini 不纯度。

4)计算加权平均 Gini 不纯度:

$$Total\_Gini = \frac{n_{left}}{n_{total}} \times Gini_{left} + \frac{n_{right}}{n_{total}} \times Gini_{right}$$

其中,$n_{left}$、$n_{right}$、$n_{total}$ 分别为左子集、右子集和总集的样本个数。

参考文献

1、【总结】机器学习中的15种分类算法

赞(0)
未经允许不得转载:网硕互联帮助中心 » 机器学习分类算法一——决策树
分享到: 更多 (0)

评论 抢沙发

评论前必须登录!