深度学习·wandb

news/2024/12/22 9:33:00/

wandb

一个好用的可视化训练过程和调参工具,建议在深度学习中使用,语法来说更加方便

前置工作

这里是一些简单的网络结构,用于测试

数据集:

  • Kaggle上HeartDisease的0-1分类问题
    df=pd.read_csv('../data/heart_attack/heart.csv')

数据集的迭代:

  • X=torch.tensor(X.values,device=config.device,dtype=torch.float32) y=torch.tensor(y.values,device=config.device,dtype=torch.float32).reshape(-1,1) dataset=TensorDataset(X,y) dataloader=DataLoader(dataset,batch_size=config.batch_size,shuffle=True)

简单的DNN

class DNN(nn.Module):def __init__(self,input_size,hidden_size,dropout:float):super().__init__()self.input_size=input_sizeself.hidden_size=hidden_sizeself.fc1=nn.Linear(self.input_size,self.hidden_size)self.fc2=nn.Linear(self.hidden_size,self.hidden_size)self.fc3=nn.Linear(self.hidden_size,1)self.dropout=nn.Dropout(dropout)def forward(self,x):x=F.leaky_relu(self.fc1(x))x=self.dropout(x)x=F.leaky_relu(self.fc2(x))x=self.dropout(x)x=self.fc3(x)return x

wandb监视训练过程

使用login()登陆

import os
os.environ["WANDB_API_KEY"] = "xxxx"
wandb.login(key=os.environ['WANDB_API_KEY'])

初始化wandb

  • 建议使用系统时间:
    current_time = datetime.now()standard_time = current_time.strftime("%Y-%m-%d %H:%M:%S")name=standard_time
  • 初始化:
    注意保存wand.run.id方便继续监视该模型
    wandb.init(project=config.project_name,name=name,config=config.__dict__)# 转换为dict    model_run_id=wandb.run.id

训练流程中记录参数

    for epoch in tqdm(range(config.epochs)):for X,y in dataloader:# 反向传播# 评估指标wandb.log({'epoch':epoch+1,'val_acc':val_acc,'best_acc':best_metric})wandb.finish()

wandb.log从接口收到对应参数,wandb.finish()完成记录,主要不要漏掉finish

继续训练

  • 提供run.id并将resume设置为must
    wandb.init(project=config.project_name,id=model_run_id,resume='must')

Artifact工件

工件可以是代码也可以是数据集
第一个参数是名称,第二个是类型

wandb.init(project=config.project_name,id=model_run_id,resume='must')
arti_dataset=wandb.Artifact('HeartDisease',type='dataset')
arti_dataset.add_dir('../data/heart_attack/')
wandb.log_artifact(arti_dataset)
```python
arti_code=wandb.Artifact('ipynb',type='code')
arti_code.add_file('./wand_test.ipynb')
wandb.log_artifact(arti_code)
wandb.finish()

Table

可视化分析

wandb.init(project=config.project_name,id=model_run_id,resume='must')
good_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])
bad_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])

在代码中加入如下:

good_cases.add_data(i,y,prediction)

一般是用于比对feature、label和prediction


http://www.ppmy.cn/news/1533340.html

相关文章

车辆重识别(2021NIPS在图像合成方面,扩散模型打败了gans网络)论文阅读2024/10/01

本文在架构方面的创新: ①增加注意头数量: 使用32⇥32、16⇥16和8⇥8分辨率的注意力,而不是只使用16⇥16 ②使用BigGAN残差块 使用Big GAN残差块对激活进行上采样和下采样 ③自适应组归一化层 将经过组归一化操作后的时间步和类嵌入到每…

CaChe的基本原理

目录 一、Cache的定义与结构 二、Cache的工作原理 三、Cache的映射与替换策略 四、Cache的写操作处理 Cache,即高速缓冲存储器,是计算机系统中位于CPU与主存之间的一种高速存储设备。它的主要作用是提高CPU对存储器的访问速度,从而优化系…

最大正方形 Python题解

最大正方形 题目描述 在一个 n m n\times m nm 的只包含 0 0 0 和 1 1 1 的矩阵里找出一个不包含 0 0 0 的最大正方形,输出边长。 输入格式 输入文件第一行为两个整数 n , m ( 1 ≤ n , m ≤ 100 ) n,m(1\leq n,m\leq 100) n,m(1≤n,m≤100),接…

APO v0.5.0 发布:可视化配置告警规则;优化时间筛选器;支持自建的ClickHouse和VictoriaMetrics

APO 新版本 v0.5.0 正式发布!本次更新主要包含以下内容: 新增页面配置告警规则和通知 在之前的版本中,APO 平台仅支持展示配置文件中的告警规则,若用户需要添加或调整这些规则,必须手动编辑配置文件。而在新版本中&a…

docker安装kafka-manager

kafkamanager docker安装_mob64ca12d80f3a的技术博客_51CTO博客 # 1、拉取镜像及创建容器 docker pull hlebalbau/kafka-manager docker run -d --name kafka-manager -p 9000:9000 --networkhost hlebalbau/kafka-manager# 2、增设端口 腾讯云# 3、修改防火墙 sudo firewall-…

LeetCode热题100速通

一丶哈希 1、两数之和(简单) 给定一个整数数组 n u m s nums nums 和一个整数目标值 t a r g e t target target,请你在该数组中找出 和为目标值 t a r g e t target target 的那 两个 整数,并返回它们的数组下标。 你可以假设…

uni-app在线预览pdf

这里推荐下载pdf.js 插件 PDF.js - Browse Files at SourceForge.net 特此注意 如果报 Promise.withResolvers is not a function 请去查看版本兼容问题 降低pdf.js版本提高node版本 下载完成后 在 static 文件夹下新建 pdf 文件夹,将解压文件放进 pdf 文件…

十三、减少磁盘延迟时间的方法

1.交替编号 让逻辑上相邻的扇区在物理上不相邻; 原因:由于磁头在读取完一个扇区之后需要等待一段时间才能再次读入下一个扇区,如果逻辑上相邻的扇区在物理上相邻的话,需要等待磁盘转完一圈才能读取到。 2.错位命名 让相邻盘面上…