最小二乘法公式推导及Python实现

您所在的位置:网站首页 python最小二乘估计 最小二乘法公式推导及Python实现

最小二乘法公式推导及Python实现

2024-07-14 04:50| 来源: 网络整理| 查看: 265

机器学习使用线性回归方法建模时,求损失函数最优解需要用到最小二乘法。相信很多朋友跟我一样,想先知道公式是什么,然后再研究它是怎么来的。所以不多说,先上公式。

对于线性回归方程\(f(x) = ax + b\),由最小二乘法得: $$a = \frac{\sum (x_{i}-\overline{x})(y_{i}-\overline{y})}{\sum (x_{i}-\overline{x})^{2}}$$ $$b = \overline{y}-a\overline{x}$$ 式中,\((x_{i}, y_{i})\)为实验所得的一组数据的真实值,\(\overline{x}为x_{i}\)的平均数,\(\overline{y}为y_{i}\)的平均数。 接下来推导一下公式是怎么得来的:

设损失函数:

\[M = \sum[y_{i}-f(x_{i})]^{2} \]

由于是线性回归,式中f(x)是线性函数,令\(f(x) = ax + b\)

现在损失函数M表示为:

\[M = \sum[y_{i}-(ax_{i}+b)]^{2}\; (1) \]

最小二乘法是求f(x)的参数a,b,使得损失函数M取得最小值。

即求M = M(a, b)在哪些点取得最小值。由多元函数极值求法,上述问题可以分别对a,b求偏导数,通过解方程组 $$\left{\begin{matrix}M_{a}(a, b) = 0 \ M_{b}(a, b) = 0 \end{matrix}\right.$$ 来解决,即令 $$\left{\begin{matrix}\frac{\partial M}{\partial a} = 0; (2) \ \ \frac{\partial M}{\partial b} = -2\sum [y_{i}-(ax_{i}+b)] = 0; (3) \end{matrix}\right.$$

由平均数性质,\(\sum x_{i} = n\overline{x}\),\(\sum y_{i} = n\overline{y}\),其中n为实验数据组数。将其带入(3)式,可得:

\[b = \overline{y} - a\overline{x}\; (4) \]

此式表明,线性回归函数必过点\((\overline{x}, \overline{y})\)

将(4)式带入(1)式,得:

\[M = \sum[y_{i}-(ax_{i}+\overline{y} - a\overline{x})]^{2} = \sum[(y_{i}-\overline{y})-a(x_{i} - \overline{x})]^{2}\; (5) \]

现对(5)式求偏导数,应用多元复合函数求导法则,推导(2)式:

\[\frac{\partial M}{\partial a} = -2\sum[(y_{i}-\overline{y})-a(x_{i} - \overline{x})](x_{i}-\overline{x}) = 0\; (6) \]

整理(6)式:

\[\sum(y_{i}-\overline{y})(x_{i}-\overline{x})-a\sum (x_{i} - \overline{x})^{2} = 0\; (7) \]

最后可得:

\[a = \frac{\sum(x_{i}-\overline{x})(y_{i}-\overline{y})}{\sum (x_{i} - \overline{x})^{2}}\; (8) \]

最后附上python代码实现最小二乘法:

import numpy as np import matplotlib.pyplot as plt x = np.array([1., 2., 3., 4., 5.]) y = np.array([1., 3., 2., 3., 5.]) x_mean = np.mean(x) y_mean = np.mean(y) num = 0.0 d = 0.0 for x_i, y_i in zip(x, y): num += (x_i - x_mean) * (y_i - y_mean) d += (x_i - x_mean) ** 2 a = num/d b = y_mean - a * x_mean print('a is %f' % a) print('b is %f' % b) y_hat = a * x + b plt.scatter(x, y) plt.plot(x, y_hat, color='g') plt.axis([0, 6, 0, 6]) plt.show()


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3