通俗解释优化的线性感知机算法:Pocket PLA

您所在的位置:网站首页 pla算法是什么 通俗解释优化的线性感知机算法:Pocket PLA

通俗解释优化的线性感知机算法:Pocket PLA

2024-07-11 10:42| 来源: 网络整理| 查看: 265

个人网站:红色石头的机器学习之路 CSDN博客:红色石头的专栏 知乎:红色石头 微博:RedstoneWill的微博 GitHub:RedstoneWill的GitHub 微信公众号:AI有道(ID:redstonewill)

在上一篇文章:

一看就懂的感知机算法PLA

我们详细介绍了线性感知机算法模型,并使用pyhon实例,验证了PLA的实际分类效果。下图是PLA实际的分类效果:

这里写图片描述

但是,文章最后我们提出了一个疑问,就是PLA只能解决线性可分的问题。对于数据本身不是线性可分的情况,又该如何解决呢?下面,我们就将对PLA进行优化,以解决更一般的线性不可分问题。

1. Pocket PLA是什么?

首先,我们来看一下线性不可分的例子:

这里写图片描述

如上图所示,正负样本线性不可分,无法使用PLA算法进行分类,这时候需要对PLA进行优化。优化后的PCA的基本做法很简单,就是如果迭代更新后分类错误样本比前一次少,则更新权重系数 w ;没有减少则保持当前权重系数 w 不变。也就是说,可以把条件放松,即不苛求每个点都分类正确,而是容忍有错误点,取错误点的个数最少时的权重系数 w 。通常在有限的迭代次数里,都能保证得到最佳的分类线。

这种算法也被称为「口袋PLA」Pocket PLA。怎么理解呢?就好像我们在搜寻最佳分类直线的时候,随机选择错误点修正,修正后的直线放在口袋里,暂时作为最佳分类线。然后如果还有错误点,继续随机选择某个错误点修正,修正后的直线与口袋里的分类线比较,把分类错误点较少的分类线放入口袋。一直到迭代次数结束,这时候放在口袋里的一定是最佳分类线,虽然可能还有错误点存在,但已经是最少的了。

2. 数据准备

该数据集包含了100个样本,正负样本各50,特征维度为2。

data = pd.read_csv('./data/data2.csv', header=None) # 样本输入,维度(100,2) X = data.iloc[:,:2].values # 样本输出,维度(100,) y = data.iloc[:,2].values

下面我们在二维平面上绘出正负样本的分布情况。

import matplotlib.pyplot as plt plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive') plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative') plt.xlabel('Feature 1') plt.ylabel('Feature 2') plt.legend(loc = 'upper left') plt.title('Original Data') plt.show()

这里写图片描述

很明显,从图中可以看出,正类和负类样本并不是线性可分的。这时候,我们就需要使用Pocket PLA。

3. Pocket PLA代码实现

首先分别对两个特征进行归一化处理,即:

# 均值 u = np.mean(X, axis=0) # 方差 v = np.std(X, axis=0) X = (X - u) / v # 作图 plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive') plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative') plt.xlabel('Feature 1') plt.ylabel('Feature 2') plt.legend(loc = 'upper left') plt.title('Normalization data') plt.show()

这里写图片描述

接下来对预测直线进行初始化,包括权重 w 初始化:

# X加上偏置项 X = np.hstack((np.ones((X.shape[0],1)), X)) # 权重初始化 w = np.random.randn(3,1)

整个迭代训练过程如下:

for i in range(100): s = np.dot(X, w) y_pred = np.ones_like(y) loc_n = np.where(s < 0)[0] y_pred[loc_n] = -1 num_fault = len(np.where(y != y_pred)[0]) if num_fault == 0: break else: r = np.random.choice(num_fault) # 随机选择一个错误分类点 t = np.where(y != y_pred)[0][r] w2 = w + y[t] * X[t, :].reshape((3,1)) s = np.dot(X, w2) y_pred = np.ones_like(y) loc_n = np.where(s < 0)[0] y_pred[loc_n] = -1 num_fault2 = len(np.where(y != y_pred)[0]) if num_fault2


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3