ConvLSTM时空预测实战代码详解 |
您所在的位置:网站首页 › 天气预报图例详解 › ConvLSTM时空预测实战代码详解 |
写在前面
时空预测是很多领域都存在的问题,不同于时间序列,时空预测不仅需要探究时间的变化,也需要关注空间的变化。许多预测问题都只片面的关注时间问题,如预测某人未来3年患某种病的概率,食堂就餐人数等,往往忽视了空间问题,如作为决策者,我不仅想知道明天患新冠的人数,而且想知道这些人会在哪些位置发病,以便精准管理。换句话说,决策者更多关注的是人群层面,而干预和施控是工作人员主要关注的问题。空间问题能够解答众多人对于预测问题最后的疑问,既:事件A可能会发生,那它会在哪里发生? 近些年关于时空预测的问题不断刺激着机器学习领域的科学家和学者,有别于传统统计学时空预测模型,机器学习或深度学习已经展现了其强大的优势:非线性拟合、高维数据的处理能力、较少担心变量共线性对模型的影响等。目前,通过长短期记忆网络LSTM和卷积神经网络CNN分别提取时间和空间特征来实现时空预测已经成为了可能,本文主要基于2015年发表在arxiv平台上的一篇building block文章《Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting》来实现时空预测的经典问题——下一帧视频预测,当然,其他机器学习时空预测模型和时空预测问题不在本文讨论范围内,有兴趣的同学可以自行尝试。 一、时空预测很多问题都可以尝试时空预测,如新冠疫情发病,交通流量、天气预报、慢性病区域发病风险预测等,空间可以被时间切分为无数个2维图片,这样用CNN来提取图片中的信息,用LSTM来处理时序问题,便能很好的解决时空预测问题,如下图所示。注意,与普通的LSTM不同,时空预测在其他维度上仍有位置变化,这是空间的体现。 和论文中一样,我们也选取Moving-MNIST数据集,既移动MNIST数据集,该公开数据集可以在多伦多大学提供的网站上下载,Moving-MNIST数据集是时空预测常用的数据集之一,数据集下载代码如下: import numpy as np from tensorflow import keras fpath = keras.utils.get_file( "moving_mnist.npy", "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy", ) dataset = np.load(fpath) print(dataset.shape)下载好数据集后,我们输出数据集的shape,结果为(20, 10000, 64, 64),表示一个seq有二十个图片,前十帧为input,后十帧为target,一共有10000个sequence,每个图片的大小为64✖64,如下图所示: 由于我们的ConvLSTM接受的输入是(sanmples,seq,wide,height,channel),因此我们要把数据集进行改造以符合模型的输入要求。 # 转换数据集的seq和samples维度,便于输入我们的模型 dataset = np.swapaxes(dataset, 0, 1) # 10000个样本太多,我们只选取1000个 dataset = dataset[:1000, ...] # 我们此时是二维灰度图片,因此要增加一维,代表单通道,如果是彩色,则为3 dataset = np.expand_dims(dataset, axis=-1) print(dataset.shape)转换后数据集的shape为(1000,20,64,64,1)已经符合模型的输入要求,接下来是划分数据集,这里要打乱索引,实现随机划分训练集和测试集。 indexes = np.arange(dataset.shape[0]) np.random.shuffle(indexes) # 打乱索引顺序 # 训练集:测试集=9:1 train_index = indexes[: int(0.9 * dataset.shape[0])] val_index = indexes[int(0.9 * dataset.shape[0]):] train_dataset = dataset[train_index] val_dataset = dataset[val_index] print(train_dataset.shape) print(val_dataset.shape)划分好数据集后我们通过除以255实现归一化,归一化要在划分数据集之后完成,不然会导致数据泄露。 # 归一化,除255就是把3基色都调到0-1区间,得到绝对色彩信息 train_dataset = train_dataset / 255 val_dataset = val_dataset / 255分离x和y,按照论文,我们是前20帧预测后20帧,类似于下图: 模型结构如下: 构建好模型后,我们开始训练,本人电脑显卡为GTX1660ti ,显存6G,显存有限,因此把batch size调小一些,并适当增大一些epoch,如果算力允许,可以增大epoch和batch size。这里定义了早期终止训练和调整学习率回调函数,因此无需担心无效训练时间增加问题。 # 定义回调 early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10) reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5) # 设置训练参数 epochs = 50 batch_size = 2 # 拟合模型. model.fit( x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val), callbacks=[early_stopping, reduce_lr], ) model.save('model.h5') 六、结果查看与自然语言处理不同,时空预测不能单纯对比accuracy、F1-score这些指标,要对比预测结果和实际的差别。从结果来看,我们预测结果比较模糊,因此,考虑改用adam优化和调整卷积核的个数至30个以及epoch改为15,结果如下: 目前来讲,空间统计学在医疗领域的运用是较为普遍的,特别是传染病和环境暴露因素的研究中,时常能看到论文作者使用时空统计模型,而ConLSTM的诞生和近些年来的改进能够推动基于机器学习的时空预测模型在传染病和环境暴露因素研究中的应用,除此之外,ConvLSTM甚至能够分析患者的视频数据,如步态或者某区域慢性病风险矩阵图等,这些领域依旧空白,我相信,AI的助力能够推动这些领域的蓬勃发展。 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |