Learning Deep Features for Discriminative Localization

server/2024/12/22 18:41:53/

 1、引言

        论文链接:https://arxiv.org/abs/1512.04150

        Bolei Zhou[1] 等重新审视了 GAP(Gobal Average Pooling),并阐明了它如何明确地使卷积神经网络具有显著的定位能力,同时提出 CAM(Class Activation Maps)[1] 技术来可视化这种能力。CAM 允许我们可视化任何给定图像上的预测类分数,突出显示 CNN(Convolutional Neural Networks) 检测到的判别对象部分。CAM 可以产生通用的定位深度特征来帮助其他研究人员了解CNN 对其任务使用的判别基础。

2、方法

图1  CNN for classification

        如图 1 所示,GAP 输出最后一个卷积层每个单元的特征图的空间平均值。这些值的加权和用于生成最终输出。类似地,我们可以计算最后一个卷积层的特征图的加权和以获得 CAM,细节如图 2 所示,即若想计算给定图片在一个类别的上的 CAM,只需取出全连接层对应类别的参数 w1、w2、...、wn,易知每个参数 w 对应于最后一个卷积层输出的一个单元的特征图,即 n 也是最后一个卷积层的输出通道数,计算最后一个卷积层的特征图的加权和就可以获得 CAM,每个特征图的权重就是对应的 w

图2  CAM

        为了更直观地观察 CAM,一般还需要经过以下步骤才能得到如图 3  所示的效果:

        (1)归一化后映射到0-255。

        (2)上采样到原图大小。

        (3)获得 heatmap。

        (4)计算待展示结果 result=0.6*heatmap+0.4*original_image。

图3  top 5 预测类别的 CAM 示例

3、总结

        [1] 提出了用于具有 GAP 的 CNN。这使得分类训练的 CNN 能够学习执行对象定位,而无需使用任何边界框注释。CAM 允许我们可视化任何给定图像上的预测类分数,突出显示 CNN 检测到的判别对象部分,有助于理解和分析神经网络的工作原理及决策过程,进而去更好地选择或设计网络。我们还可以利用可视化的信息引导网络更好的学习,例如可以利用 CAM 信息通过"擦除"或""裁剪""的方式对数据进行增强。

        作者开源的代码在:GitHub - zhoubolei/CAM: Class Activation Mapping,Pytorch 实现的最后一次更新的时间为 2021 年 6 月 30 日,故使用的 Pytorch 版本较老,本人重写了一遍如下所示:

import cv2
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torchvision.models.feature_extraction import create_feature_extractordef showCAM(img_path, model, stage_name, class_dict, transform):"""展示 model 预测 img 概率最高的 5 个类别的 CAM:param img_path: 待展示 CAM 的图片路径:param model::param stage_name: model 的最后一个 stage 名称:param class_dict: 数据集字典列表,每个字典的形式为 {class_id, class_name}:param transform: 预处理 model 的输入图片:return:"""model.eval()# 全连接层的权重last_layer = list(model.modules())[-1]fc_weights = last_layer.weightoriginal_img = Image.open(img_path)# softmax计算概率img = transform(original_img).unsqueeze(0)output = model(img)psort = torch.sort(F.softmax(output, dim=1), descending=True)prob, cls_idx = psort# top5的类别和概率top5 = [(i.item(), j.item()) for i, j in zip(cls_idx.view(-1), prob.view(-1))][:5]fig, axs = plt.subplots(2, 3)axs.reshape(-1)[0].imshow(np.asarray(original_img))for idx, cls_prob in enumerate(top5):# 获取对应类别的权重cls_weights = fc_weights[cls_prob[0]].detach().unsqueeze(0)  # 1, class_num# 特征图提取feature_extractor = create_feature_extractor(model, return_nodes={stage_name: "feature_map"})forward = feature_extractor(img)b, c, h, w = forward["feature_map"].shapefeature_map = forward["feature_map"].detach().reshape(c, h * w)# 激活类别特征映射CAM = torch.mm(cls_weights, feature_map).reshape(h, w)# 归一化后映射到0-255CAM = (CAM - torch.min(CAM)) / (torch.max(CAM) - torch.min(CAM))CAM = (CAM.numpy() * 255).astype("uint8")# 上采样到原图大小upsample = cv2.resize(CAM, original_img.size)# 热力图heatmap = cv2.applyColorMap(upsample, cv2.COLORMAP_JET)heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)result = heatmap * 0.6 + np.asarray(original_img) * 0.4# result = heatmapaxs.reshape(-1)[idx + 1].imshow(np.uint8(result))axs.reshape(-1)[idx + 1].text(-10, -10, f"{class_dict[cls_prob[0]]}: {cls_prob[1]:.3f}", fontsize=12,color="black")plt.show()

参考文献

[1] Bolei ZhouAditya KhoslaAgata Lapedriza, Aude Oliva, and Antonio TorralbaLearning Deep Features for Discriminative Localization. In CVPR, 2016.


http://www.ppmy.cn/server/104916.html

相关文章

HTML详解

1. 文档结构标签 <!DOCTYPE html>&#xff1a;声明文档类型&#xff0c;告诉浏览器这是一个HTML5文档。<html>&#xff1a;HTML文档的根元素&#xff0c;包含整个HTML文档。<head>&#xff1a;包含文档的元数据&#xff08;metadata&#xff09;&#xff0c…

【全网行为管理解决方案】上网行为系统有哪些?

全网行为管理系统是一种用于监控、管理和优化企业内部网络中所有用户活动及网络流量的技术解决方案。 这类系统可以帮助企业提高网络安全、优化网络性能&#xff0c;并确保网络使用符合公司政策及法规要求。以下是几种常用的上网行为管理系统&#xff1a; 一、安企神 特点&am…

HubSpot 自动化营销平台助力出海企业精准获客与转化 | 自动化营销

HubSpot 提供了多个开源 cms 和一体化且全面的解决方案&#xff0c;可帮助出海企业优化内容营销策略 HubSpot 自动化营销加速国际化 随着全球化的推进&#xff0c;越来越多的企业开始寻求拓展国际市场&#xff0c;而在这个过程中&#xff0c;有效的客户关系管理和营销自动化成…

react redux异步请求

1,创建store //store/modules/channelStore.js import { createSlice } from "reduxjs/toolkit" import axios from "axios"const channelStore createSlice({name: channel,initialState: {channelList: []},reducers: {setChannels (state, action) {s…

javascript利用三元运算符制作补零程序

这里的补零是当数字小于0时自动在前面补零&#xff0c;大于等于10时&#xff0c;前面不用补零。 代码如下 <html><head><meta charset"UTF-8"><title></title></head><body><script>let numprompt("请输入一…

leetcode 3 无重复字符的最长子串

leetcode 3 无重复字符的最长子串 正文普通解法双指针 正文 普通解法 重点观察示例 3。本题重点是创建一个动态区间&#xff0c;然后判断位于这个动态区间之外的字符是否被包含在这个动态区间范围内。并且对于 s 长度小于 1 的情况要重点进行讨论。 class Solution:def lengt…

【React原理 - 任务调度和时间分片详解】

概述 在React15的时候&#xff0c;React使用的是从根节点往下递归的方式同步创建虚拟Dom&#xff0c;由于递归具有同步不可中断的特性&#xff0c;所以当执行长任务时(通常以60帧为标准&#xff0c;即16.6ms)就会长时间占用主线程长时间无响应&#xff0c;导致页面卡顿&#x…

.Net 6 WebApi项目中使用Log4Net详解

众所周知as we know&#xff0c; log4Net是一个很方便的日志输出工具&#xff0c;但是&#xff0c;每次使用&#xff0c;日志都没有顺利输出过.....各种不知名问题.......所以就记录一下&#xff0c;方便下次使用。 具体的与原理和基础在此不做赘述&#xff0c;咱直接上干货&a…