[PyTorch]即插即用的热力图生成

您所在的位置:网站首页 图像的特征图怎么画的 [PyTorch]即插即用的热力图生成

[PyTorch]即插即用的热力图生成

2024-07-10 15:14| 来源: 网络整理| 查看: 265

        先上张效果图,本来打算移植霹雳老师的使用Pytorch实现Grad-CAM并绘制热力图。但是看了下代码,需要骨干网络按照标准写法(即将特征层封装为features数组),而我写的网络图省事并没有进行封装,改造网络的代价又太大了,所以干脆直接重写一个。

一、生成热力图

        大致可以分为三步:①读取图片;②前向传递运算;③用特征向量生成特征图。而图片的resize图简单可以直接用transforms,后面反正也是直接resize回来的,并不会造成变形。

# 加载一个transforms用于变形,input_shape为预设的图像尺寸 transform = transforms.Compose([transforms.Resize((input_shape[0],input_shape[1])), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]) image = Image.open(image_path) #image_path为文件路径 input_tensor = transform(image) #将图片转换为tensor类型 input_batch = input_tensor.unsqueeze(0) #为tensor添加batch维度 # 前向传递 model.eval() with torch.no_grad(): output = model(input_batch)

        使用特征图生成热力图的原理是:将该维度上所有的tensor进行叠加,然后将生成的矩阵变形回输入向量的尺寸

heatmap = torch.sum(output, dim=1) #所有通道求和 max_value = torch.max(heatmap) min_value = torch.min(heatmap) heatmap = (heatmap-min_value)/(max_value-min_value)*255 heatmap = heatmap.cpu().numpy().astype(np.uint8).transpose(1,2,0) # 提取热力图 heatmap = cv2.resize(heatmap, input_shape,interpolation=cv2.INTER_LINEAR) # 还原尺寸 # 将矩阵转换为image类 heatmap=cv2.applyColorMap(heatmap,cv2.COLORMAP_JET) heatimg = Image.fromarray(heatmap) 二、叠加原图

        直接使用plt进行叠加!

# 将热力图叠加到原图上 org_size = image.size heatimg = heatimg.resize(org_size) #将热力图变回输入图像的尺寸 plt.axis('off') plt.imshow(image) plt.imshow(heatimg, alpha=0.5) # alpha为热力图的透明度 # 显示叠加后的图形 plt.show() 三、总结

        这段代码和霹雳老师的Grad-CAM对比优劣都很明显,优点是代码比较简单。上可以通过插入前向传递的环境直接得到任何层的热力图。但缺点就是不能关注特定的类别,且生成的热力图也不是很美观。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3