python + matplotlib 实现线性回归3d图像展示

您所在的位置:网站首页 线性回归模型图 python + matplotlib 实现线性回归3d图像展示

python + matplotlib 实现线性回归3d图像展示

2024-06-04 16:24| 来源: 网络整理| 查看: 265

初学线性回归,在网上搜索了一圈没找到一个满意的可视化图像,决定自己手写一个。

数学模型: y = W1 * x1 + W2 * x2 + b

损失函数:均方误差

真实结果:y = 2 * x1 + 3 * x2 + 4.2

目标:绘制 W1和W2取不同值时的损失函数图像

import torch import matplotlib.pyplot as plt def synthetic_data(w, b, num_examples): # @save """生成y=Xw+b+噪声 : 期望,标准差,size(样本量,权重数)""" X = torch.normal(0, 1, (num_examples, len(w))) y = torch.matmul(X, w) + b y += torch.normal(0, 0.01, y.shape) return X, y.reshape((-1, 1)) def linreg(X, w, b): """线性回归模型""" return torch.matmul(X, w) + b def squared_loss(y_hat, y): #@save """均方损失""" return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2 # 根据x,y坐标(W1, W2)返回z值 def get_z(w1, w2): loss = squared_loss net = linreg return loss(net(features, torch.tensor([w1, w2]), true_b), labels).mean() true_w = torch.tensor([2., 3.]) true_b = 4.2 features, labels = synthetic_data(true_w, true_b, 1000) x = torch.arange(0, 4, 0.1) y = torch.arange(1, 5, 0.1) xs, ys = torch.meshgrid(x, y, indexing='ij') zs = torch.zeros(xs.shape) for i in range(len(x)): for j in range(len(y)): zs[i, j] = get_z(xs[i][j], ys[i][j]) fig = plt.figure(figsize=(10,8)) ax = fig.add_subplot(111, projection='3d') surf = ax.plot_surface(xs, ys, zs, cmap='coolwarm') fig.colorbar(surf) ax.set_xlabel('W1') ax.set_ylabel('W2') ax.set_zlabel('Loss') plt.show()



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3