WandB使用笔记
最近看代码,发现代码中有wandb有关的内容,搜索了一下发现是一个模型训练工具,然后学习了一下,这里记录一下使用过程,方便以后查阅。
WandB使用笔记
- 登录WandB 并 创建团队
- 安装 WandB 并 登录
- 模型训练过程跟踪
- 模型版本管理
- 自动调参
- 不同的模型训练工具对比
- 参考资料
作者自注:之前训练模型一直使用的是Visdom,感觉非常好用,然后现在学习了一下WandB,发现先各有优劣。Visdom的曲线实时跟踪效果好,但是功能简单。WandB曲线实时跟踪效果差(可能是我的网的问题),但是功能强大,可以保存每次模型调优的参数,这样就不用手动再记录了;可以实现模型的版本管理,这样就可以随便改代码,不用担心改坏了;可以进行参数分析,这样就可以有目的的进行参数调优;可以进行自动调参,这样可在完成粗调制后进行局部的参数寻优。感觉以后两个可以同时使用,提高模型调优的效率
登录WandB 并 创建团队
点击下面的网站进入WandB:https://wandb.ai/site,然后点击界面中的 LOGIN
进行登录。
如下需要选择登录的方式,这里我选择的是 GitHub 。
完成登陆后进入如下初始界面,点击图片中红框中的内容,创建一个新的 team
。
之后进入如下界面,输入团队名称,并点击 Create team
,完成团队的创建。
团队创建成功后出现如下界面,选择是否把自己的 runs
更新到 team
,这里选择 Update
。
如此就完成了登录和团建创建过程!
如果想要删除创建的团队,则在主界面点击创建的团队,如下图所示:
进入团队后,点击 Team settings
,如下图所示:
接着滑动到最下面,点击 Delete team
:
接着需要你输入 团队的名称 进行删除,这里的逻辑跟GitHub删除项目一样。
安装 WandB 并 登录
使用 pip
安装 WandB:
pip install wandb
验证安装是否成功:
wandb --version
首次使用 WandB 时,需要登录账户:
wandb login
登录后,WandB 会提示输入 API 密钥。可以从 WandB 的 API 密钥页面 获取密钥,点击图片中的红框部分,复制密钥,然后粘贴到上图的 3 标识的地方,并点击回车,如此就完成了登录过程。
如果你之前已经登陆过了,则会出现如下的内容:
然后在终端输入如下的命令即可重新登录:
wandb login --relogin
模型训练过程跟踪
将如下代码复制到PyCharm中,进行实验。
import wandb
import torch
from torch import nn
import torchvision
from torchvision import transforms
import datetime
from argparse import Namespacedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(project_name='wandb_demo',batch_size=512,hidden_layer_width=64,dropout_p=0.1,lr=1e-4,optim_type='Adam',epochs=150,ckpt_path='checkpoint.pt'
)def create_dataloaders(config):transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./mnist/", train=True, download=True, transform=transform)ds_val = torchvision.datasets.MNIST(root="./mnist/", train=False, download=True, transform=transform)ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))dl_train = torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True, drop_last=True)dl_val = torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, drop_last=True)return dl_train, dl_valdef create_net(config):net = nn.Sequential()net.add_module("conv1", nn.Conv2d(in_channels=1, out_channels=config.hidden_layer_width, kernel_size=3))net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))net.add_module("conv2", nn.Conv2d(in_channels=config.hidden_layer_width,out_channels=config.hidden_layer_width, kernel_size=5))net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))net.add_module("dropout", nn.Dropout2d(p=config.dropout_p))net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1)))net.add_module("flatten", nn.Flatten())net.add_module("linear1", nn.Linear(config.hidden_layer_width, config.hidden_layer_width))net.add_module("relu", nn.ReLU())net.add_module("linear2", nn.Linear(config.hidden_layer_width, 10))net.to(device)return netdef train_epoch(model, dl_train, optimizer):model.train()for step, batch in enumerate(dl_train):features, labels = batchfeatures, labels = features.to(device), labels.to(device)preds = model(features)loss = nn.CrossEntropyLoss()(preds, labels)loss.backward()optimizer.step()optimizer.zero_grad()return modeldef eval_epoch(model, dl_val):model.eval()accurate = 0num_elems = 0for batch in dl_val:features, labels = batchfeatures, labels = features.to(device), labels.to(device)with torch.no_grad():preds = model(features)predictions = preds.argmax(dim=-1)accurate_preds = (predictions == labels)num_elems += accurate_preds.shape[0]accurate += accurate_preds.long().sum()val_acc = accurate.item() / num_elemsreturn val_acc
def train(config=config):dl_train, dl_val = create_dataloaders(config)model = create_net(config);optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)# ======================================================================nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)model.run_id = wandb.run.id# ======================================================================model.best_metric = -1.0for epoch in range(1, config.epochs + 1):model = train_epoch(model, dl_train, optimizer)val_acc = eval_epoch(model, dl_val)if val_acc > model.best_metric:model.best_metric = val_acctorch.save(model.state_dict(), config.ckpt_path)nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")# ======================================================================wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})# ======================================================================# ======================================================================wandb.finish()# ======================================================================return model
上述代码最关键的就是如下三个部分:
- 初始化部分:
wandb.init(project=config.project_name, config=config.__dict__, name=nowtime, save_code=True)
- 模型训练参数上传
wandb.log({'epoch': epoch, 'val_acc': val_acc, 'best_val_acc': model.best_metric})
- 模型训练完成关闭wandb:
wandb.finish()
最后在PyCharm中输入如下代码,即可运行上述代码:
model = train(config)
代码运行成功,即可出现如下的界面,点击下图中红框中的部分,即可跳转到曲线监视界面。
模型训练过程监视界面如下图所示:
点击下图中的红框部分,更改曲线的横坐标值。
如下图所示,将横坐标值更改为 epoch。
然后我们还可以增加一个 section
。
在新的 section
中添加新的显示模块,如下图所示:
此处我们添加了验证集的准确率,实现实时的监控。
模型训练结束,我们可以点击 runs
查看历史记录。
如下图可以看到,我们刚才监视的曲线,如图中的长方形红框所示。然后点击小红框中的 runs
,查看每一次训练过程的模型参数。
每一次模型训练的参数如下图所示,可以选择图中红框中的内容,选择需要的参数进行显示。
可选择的指标如下图所示:
对于某些我们比较关注的指标,我们可以将其固定显示:
固定后,我们回到 Workspace
界面,即可看到固定的参数。
模型版本管理
除了可以记录实验日志传递到 wandb 网站的云端服务器 并进行可视化分析。wandb还能够将实验关联的数据集,代码和模型 保存到 wandb 服务器。我们可以通过 wandb.log_artifact的方法来保存任务的关联的重要成果。例如 dataset, code,和 model,并进行版本管理。
当我们跑出一个相对不错的结果时,我们希望把这个结果给保存下来,此时我们就可以使用该功能。
我们先使用run_id 恢复 run任务,以便继续记录。
import wandb
# resume the run
run = wandb.init(project='wandb_demo', id='6h5xkv16', resume='allow')
上述代码中的 id
是用来关联我们训练的 runs
的,参数的值来自下图红框中的内容,想搞关联某一次的训练过程,就把某一次训练的 ID
写入上述代码。
保存数据集的代码:
# save dataset
arti_dataset = wandb.Artifact(name='mnist', type='dataset')
arti_dataset.add_dir('mnist/')
wandb.log_artifact(arti_dataset)
保存模型文件的代码:
# save code
arti_code = wandb.Artifact(name='py', type='code')
arti_code.add_file('./wandb_test.py')
wandb.log_artifact(arti_code)
保存模型权重的代码:
# save model
arti_model = wandb.Artifact(name='cnn', type='model')
arti_model.add_file(config.ckpt_path)
wandb.log_artifact(arti_model)
最后结束时要使用一下代码:
# finish时会提交保存
wandb.finish()
上传后的效果如图所示:
自动调参
sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。
使用Sweep的3步骤:
- 配置 sweep_config
# 配置 Sweep config
sweep_config = {'method': 'random', # 选择调优算法,超参数搜索方法:随机搜索'metric': { # 定义调优目标'name': 'val_acc','goal': 'maximize'},'parameters': { # 定义超参空间'project_name': {'value': 'wandb_demo'}, # 固定不变的超参'epochs': {'value': 10},'ckpt_path': {'value': 'checkpoint.pt'},'optim_type': { # 离散型分布超参'values': ['Adam', 'SGD', 'AdamW']},'hidden_layer_width': {'values': [16, 32, 48, 64, 80, 96, 112, 128]},'lr': { # 连续型分布超参'distribution': 'log_uniform_values','min': 1e-6,'max': 0.1},'batch_size': {'distribution': 'q_uniform','q': 8,'min': 32,'max': 256,},'dropout_p': {'distribution': 'uniform','min': 0,'max': 0.6,}},# 'early_terminate': { # 定义剪枝策略 (可选)# 'type': 'hyperband', # 使用 HyperBand 作为早停策略# 'min_iter': 3, # 最小评估迭代次数(第 3 次迭代后开始考虑剪枝)# 'eta': 2, # 成倍增长的资源分配比例(每次迭代中仅保留约 1/eta 的实验)# 's': 3 # HyperBand 的最大阶数,影响资源分配的层级# }
}
from pprint import pprint
pprint(sweep_config)
Sweep支持如下3种调优算法:
(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。
(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。
(3)贝叶斯搜索:bayes.
创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。
- 初始化 sweep controller
# 初始化 sweep controller
sweep_id = wandb.sweep(sweep_config, project=config.project_name)
- 启动 sweep agents
# 启动 Sweep agent
# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)
等代码跑完我们就有了一个 sweep
,如下图所示:
进入 sweep
之后就可以添加 Parallel coordinates
和 Parameter importance
进行参数分析。
不同的模型训练工具对比
工具 | 实验管理 | 数据版本控制 | 模型部署 | 团队协作 | 离线支持 | 特点 |
---|---|---|---|---|---|---|
TensorBoard | ✅ | ❌ | ❌ | ❌ | ✅ | 轻量级工具,适合快速原型开发 |
WandB | ✅ | ✅ | ✅ | ✅ | ✅ | 功能全面,支持超参数调优和实时协作 |
Comet | ✅ | ❌ | ❌ | ✅ | ✅ | 简单易用,支持离线模式 |
MLflow | ✅ | ✅ | ✅ | ✅ | ✅ | 实验管理与模型部署一体化 |
Neptune | ✅ | ❌ | ❌ | ✅ | ❌ | 强大的可视化功能 |
Sacred | ✅ | ❌ | ❌ | ❌ | ✅ | 极简实验管理工具 |
Polyaxon | ✅ | ✅ | ✅ | ✅ | ❌ | 分布式训练与大规模实验管理支持 |
DVC | ✅ | ✅ | ❌ | ❌ | ✅ | 专注于数据和模型版本控制 |
ClearML | ✅ | ✅ | ✅ | ✅ | ✅ | 全面的 MLOps 功能 |
参考资料
30分钟吃掉wandb模型训练可视化
wandb我最爱的炼丹伴侣操作指南
30分钟吃掉wandb可视化自动调参
wandb可视化调参完全指南