PyTorch指标计算库TorchMetrics详解

您所在的位置:网站首页 两向量夹角的正弦值是什么 PyTorch指标计算库TorchMetrics详解

PyTorch指标计算库TorchMetrics详解

2023-05-21 02:03| 来源: 网络整理| 查看: 265

什么是指标

弄清楚需要评估哪些指标(metrics)是深度学习的关键。有各种指标,我们就可以评估ML算法的性能。一般来说,指标(metrics)的目的是监控和量化训练过程。在一些技术中,如学习率调度learning-rate scheduling或提前停止early stopping,指标是用来调度和控制的关键。虽然也可以在这里使用损失loss,但指标是首选,因为它们能更好地代表训练目标。与损失相反,指标不需要是可微的(事实上很多都不是),但其中一些是可微的。如果指标本身是可微的,并且它是基于纯PyTorch实现,那么它也跟损失一样可以用来进行反向传播。

简介

TorchMetrics对80多个PyTorch指标进行了代码实现,且其提供了一个易于使用的API来创建自定义指标。对于这些已实现的指标,如准确率Accuracy、召回率Recall、精确度Precision、AUROC、RMSE、R²等,可以开箱即用;对于尚未实现的指标,也可以轻松创建自定义指标。主要特点有:

一个标准化的接口,以提高可重复性兼容分布式训练经过了严格的测试在批次batch之间自动累积在多个设备之间自动同步 安装

使用pip:

pip install torchmetrics 或使用 conda: 1conda install -c conda-forge torchmetrics

使用

与torch.nn类似,大多数指标都有一个基于类的版本和一个基于函数的版本。

函数版本

函数版本的指标实现了计算每个度量所需的基本操作。它们是简单的python函数,接收torch.tensors作为输入,然后返回torch.tensor类型的相对应的指标。一个简单的示例如下:

123456789import torch# import our libraryimport torchmetrics# simulate a classification problempreds = torch.randn(10, 5).softmax(dim=-1)target = torch.randint(5, (10,))acc = torchmetrics.functional.accuracy(preds, target)

模块版本

几乎所有的函数版本的指标都有一个相应的基于类的版本,该版本在实际代码中调用对应的函数版本。基于类的指标的特点是具有一个或多个内部状态(类似于PyTorch模块的参数),使其能够提供额外的功能:

对多个批次的数据进行累积多个设备之间的自动同步指标运算

一个示例如下:

12345678910111213141516171819202122import torch# import our libraryimport torchmetrics# initialize metricmetric = torchmetrics.Accuracy()n_batches = 10for i in range(n_batches): # simulate a classification problem preds = torch.randn(10, 5).softmax(dim=-1) target = torch.randint(5, (10,)) # metric on current batch acc = metric(preds, target) print(f"Accuracy on batch {i}: {acc}")# metric on all batches using custom accumulationacc = metric.compute()print(f"Accuracy on all data: {acc}")# Reseting internal state such that metric ready for new datametric.reset() 每次调用指标的前向计算时,一方面对当前看到的一个批次的数据进行指标计算,另一方面更新内部指标状态,该状态记录了当前看到的所有数据。内部状态需要在 epoch之间被重置,并且不应该在训练、验证和测试之间混淆。因此,强烈建议按不同的模式重新初始化指标,如下例所示: 1234567891011121314151617181920212223242526272829from torchmetrics.classification import Accuracytrain_accuracy = Accuracy()valid_accuracy = Accuracy()for epoch in range(epochs): for x, y in train_data: y_hat = model(x) # training step accuracy batch_acc = train_accuracy(y_hat, y) print(f"Accuracy of batch{i} is {batch_acc}") for x, y in valid_data: y_hat = model(x) valid_accuracy.update(y_hat, y) # total accuracy over all training batches total_train_accuracy = train_accuracy.compute() # total accuracy over all validation batches total_valid_accuracy = valid_accuracy.compute() print(f"Training acc for epoch {epoch}: {total_train_accuracy}") print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}") # Reset metric states after each epoch train_accuracy.reset() valid_accuracy.reset()

自定义指标

如果想使用一个尚不支持的指标,可以使用TorchMetrics的API来实现自定义指标,只需将torchmetrics.Metric子类化并实现以下方法:

实现__init__方法,在这里为每一个指标计算所需的内部状态调用self.add_state;实现update方法,在这里进行更新指标状态所需的逻辑;实现compute方法,在这里进行最终的指标计算。 RMSE例子

以均方根误差(RMSE, Root mean squared error)为例,来看怎样自定义指标。均方根误差的计算公式为:

RMSE=1N∑n=1N(y^i−yi)2

为了正确计算RMSE,我们需要两个指标状态:sum_squared_error来跟踪目标y^和预测y之间的平方误差;n_observations来统计我们进行了多少次观测。

123456789101112131415161718from torchmetrics.metric import Metricclass MeanSquaredError(Metric): def __init__(self): super().__init__() # 添加状态,dist_reduce_fx指定了用来在多进程之间聚合状态所用的函数 self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("n_observations", default=tensor(0), dist_reduce_fx="sum") def update(self, preds, target): # 更新状态 self.sum_squared_error += torch.sum((preds-target)**2) self.n_observations += preds.numel() def compute(self): """Computes mean squared error over state.""" return torch.sqrt(self.sum_squared_error/self.n_observations)

关于实现自定义指标的实际例子和更多信息,看这个页面。

指标运算

TorchMetrics支持大多数Python内置的算术、逻辑和位操作的运算符。比如:

1234first_metric = MyFirstMetric()second_metric = MySecondMetric()new_metric = first_metric + second_metric 这种运算模式可以适用于以下运算符( a是指标, b可以是指标、张量、整数或浮点数):

加法(a + b)按位与(a & b)等价(a == b)向下取整除floor division (a // b)大于等于 (a >= b)大于 (a > b)小于等于 (a 'map': tensor(0.6000),>>> 'map_50': tensor(1.),>>> 'map_75': tensor(1.),>>> 'map_large': tensor(0.6000),>>> 'map_medium': tensor(-1.),>>> 'map_per_class': tensor(-1.),>>> 'map_small': tensor(-1.),>>> 'mar_1': tensor(0.6000),>>> 'mar_10': tensor(0.6000),>>> 'mar_100': tensor(0.6000),>>> 'mar_100_per_class': tensor(-1.),>>> 'mar_large': tensor(0.6000),>>> 'mar_medium': tensor(-1.),>>> 'mar_small': tensor(-1.)}

Donate

WeChat Pay

# PyTorch Neovim预配置库NvChad探索 Pandas可视化数据分析工具D-Tale详解


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3