一个简单的图像分类项目(三)编写脚本:参数设置
项目所有的参数均在这个脚本内设置,方便管理和移植。
sript.setting.py:
import osimport torch
from PIL import Image# 获取当前脚本的绝对路径
script_dir = os.path.dirname(os.path.abspath(__file__))# 获取当前脚本的上一层目录
parent_dir = os.path.dirname(script_dir)# 学习集路径
train_path = os.path.join(parent_dir, "image", "train")
# print(train_path)# 测试集路径
test_path = os.path.join(parent_dir, "image", "test")# 错误集路径
error_path = os.path.join(parent_dir, "image", "error")# 模型路径
model_path = os.path.join(parent_dir, "model")# 对应的数字
label_dict = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}# 数字对应的类别标签
label_name = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}# 标签类别的数量
classes = len(label_dict)# 标准化参数
# normalize_values = {(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)}
normalize_mean = (0.485, 0.456, 0.406) # 标准化参数, 均值
normalize_std = (0.229, 0.224, 0.225) # 标准化参数, 标准差
normalize_size = (32, 32) # 标准化参数, 图像尺寸# batch size
batch_size = 128# 子进程数量
num_workers = 12# 网络是否预训练
is_pretrained = False# 学习率
learning_rate = 0.01# 训练轮数
num_epoches = 100# 优先使用GPU训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 图片数据出错后的默认数据
error_default_path = os.path.join(error_path, "error_default.png")
_img = Image.open(error_default_path)
error_default_img = _img.convert('RGB') # 转换成RGB模式