Pytorch小技巧:布尔类型(True/False)转换为浮点类型(1.0/0.0)的成长史 |
您所在的位置:网站首页 › 1的布尔值不是false › Pytorch小技巧:布尔类型(True/False)转换为浮点类型(1.0/0.0)的成长史 |
文章目录
引言方法1:基于列表生成式方法2:基于torch.where()方法3:强制类型转换比较三种方法结束语
引言
在深度学习和PyTorch中,数据类型的处理是一个关键环节。我们经常需要应对各种数据类型,包括布尔值(True/False)和浮点数(1.0/0.0)。有时,为了满足特定的计算或操作需求,我们需要将布尔类型的tensor转换为浮点类型或整数类型的tensor。随着编程经验的积累,我们处理此类需求的方式也会不断优化。以下是我不同阶段处理此类需求的代码缩影。 方法1:基于列表生成式我们可以通过列表生成式 + if-else语句实现将布尔类型的tensor转换为浮点类型/整数类型的tensor。 完整代码 import time import torch bool_tensor = torch.tensor([True, False, True]) start_time = time.time() # 统计10w次,比较三种代码的时间复杂度 for _ in range(100000): # 列表生成式 float_tensor = torch.tensor([1.0 if value else 0.0 for value in bool_tensor]) end_time = time.time() print(float_tensor, "cost_time: ", end_time - start_time)运行结果
我们可以通过torch.where()函数实现将布尔类型的tensor转换为浮点类型/整数类型的tensor。 完整代码 import time import torch bool_tensor = torch.tensor([True, False, True]) start_time = time.time() # 统计10w次,比较三种代码的时间复杂度 for _ in range(100000): # torch.where float_tensor = torch.where(bool_tensor, 1.0, 0.0) end_time = time.time() print(float_tensor, "cost_time: ", end_time - start_time)运行结果 可以看出,使用torch.where()处理10w次从布尔类型的tensor转换为浮点类型/整数类型的tensor需求,需要大约0.78s。 方法3:强制类型转换在某次偶然的尝试下,我发现可以通过强制类型转换实现将布尔类型的tensor转换为浮点类型/整数类型的tensor。 完整代码 import time import torch bool_tensor = torch.tensor([True, False, True]) start_time = time.time() # 统计10w次,比较三种代码的时间复杂度 for _ in range(100000): # 类型转换 float_tensor = bool_tensor.float() end_time = time.time() print(float_tensor, "cost_time: ", end_time - start_time)运行结果
一言以蔽之:强制类型转换真香!!! 结束语 亲爱的读者,感谢您花时间阅读我们的博客。我们非常重视您的反馈和意见,因此在这里鼓励您对我们的博客进行评论。您的建议和看法对我们来说非常重要,这有助于我们更好地了解您的需求,并提供更高质量的内容和服务。无论您是喜欢我们的博客还是对其有任何疑问或建议,我们都非常期待您的留言。让我们一起互动,共同进步!谢谢您的支持和参与!我会坚持不懈地创作,并持续优化博文质量,为您提供更好的阅读体验。谢谢您的阅读! |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |