当前位置: 首页 > news >正文

过拟合与欠拟合、批量标准化

过拟合与欠拟合

过拟合(Overfitting)

1、基本概念:过拟合指的是模型在训练数据上表现很好,但在未见过的测试数据上表现较差的情况。过拟合发生的原因是模型过于复杂,能够记住训练数据的细节和噪声,而不是学习数据的通用模式。

2、特征

  • 模型在训练数据上的准确度高。

  • 模型在测试数据上的准确度较低。

  • 模型的参数数量过多,容易记忆训练数据。

3、防止过拟合的方法

  • 数据集扩增:增加更多的训练数据,可以减少过拟合的风险。

  • 正则化:通过添加正则化项,如L1正则化(Lasso)或L2正则化(Ridge),来惩罚模型参数的大小,使模型更简单。

  • 特征选择:选择最重要的特征,降低模型的复杂度。

  • 交叉验证:使用交叉验证来估计模型的性能,选择最佳的模型参数。

  • 早停止:在训练过程中监控验证集的性能,当性能开始下降时停止训练,以防止过拟合。

欠拟合(Underfitting)

1、基本概念:欠拟合表示模型太过简单,无法捕获数据中的关键特征和模式。模型在训练数据和测试数据上的性能都较差。

2、特征

  • 模型在训练数据上的准确度较低。

  • 模型在测试数据上的准确度也较低。

  • 模型可能太简单,参数数量不足。

3、防止欠拟合的方法

  • 增加模型复杂度:使用更复杂的模型,例如增加神经网络的层数或增加决策树的深度。

  • 增加特征:添加更多的特征或进行特征工程,以捕获更多数据的信息。

  • 减小正则化强度:如果使用了正则化,可以降低正则化的强度,使模型更灵活。

  • 调整超参数:调整模型的超参数,如学习率、批量大小等,以改善模型的性能。

  • 使用更多数据:如果可能的话,增加训练数据可以提高模型的性能。

总的来说,过拟合和欠拟合都是需要非常注意的问题。

选择合适的模型复杂度、正则化方法和特征工程技巧可以帮助在训练机器学习模型时避免这些问题,获得更好的泛化性能。

解决过拟合

  • L1 正则化 更适合用于产生稀疏模型,会让部分权重完全为零,适合做特征选择。

  • L2 正则化 更适合平滑模型的参数,避免过大参数,但不会使权重变为零,适合处理高维特征较为密集的场景。

L2正则化

L2 正则化通过在损失函数中添加权重参数的平方和来实现,目标是惩罚过大的参数值。

L_{\text{total}}(\theta) = L(\theta) + \lambda \cdot \frac{1}{2} \sum_{i} \theta_i^2

  • L(\theta) 是原始损失函数(比如均方误差、交叉熵等)。

  • \lambda 是正则化强度,控制正则化的力度。

  • \theta_i是模型的第 $$i$$ 个权重参数。

  • \frac{1}{2} \sum_{i} \theta_i^2 是所有权重参数的平方和,称为 L2 正则化项。

L2 正则化会惩罚权重参数过大的情况,通过参数平方值对损失函数进行约束。

梯度更新

\theta_{t+1} = \theta_t - \eta \left( \nabla L(\theta_t) + \lambda \theta_t \right)

  • \eta 是学习率。

  • \nabla L(\theta_t)是损失函数关于参数\theta_t的梯度。

  • \lambda \theta_t 是 L2 正则化项的梯度,对应的是参数值本身的衰减。

参数越大惩罚力度就越大,从而让参数逐渐趋向于较小值,避免出现过大的参数。

作用

  1. 防止过拟合:当模型过于复杂、参数较多时,模型会倾向于记住训练数据中的噪声,导致过拟合。L2 正则化通过抑制参数的过大值,使得模型更加平滑,降低模型对训练数据噪声的敏感性。

  2. 限制模型复杂度:L2 正则化项强制权重参数尽量接近 0,避免模型中某些参数过大,从而限制模型的复杂度。通过引入平方和项,L2 正则化鼓励模型的权重均匀分布,避免单个权重的值过大。

  3. 提高模型的泛化能力:正则化项的存在使得模型在测试集上的表现更加稳健,避免在训练集上取得极高精度但在测试集上表现不佳。

  4. 平滑权重分布:L2 正则化不会将权重直接变为 0,而是将权重值缩小。这样模型就更加平滑的拟合数据,同时保留足够的表达能力。

import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)  # L2 正则化

L1正则化

L1 正则化通过在损失函数中添加权重参数的绝对值之和来约束模型的复杂度。

L_{\text{total}}(\theta) = L(\theta) + \lambda \sum_{i} |\theta_i|

  • L(\theta)是原始损失函数。

  • \lambda 是正则化强度,控制正则化的力度。

  • |\theta_i| 是模型第 i 个参数的绝对值。

  • \sum_{i} |\theta_i|是所有权重参数的绝对值之和,这个项即为 L1 正则化项。

L1 正则化依赖于参数的绝对值,其梯度更新时不是简单的线性缩小,而是通过符号函数来直接调整参数的方向。

作用
  1. 稀疏性:L1 正则化的一个显著特性是它会促使许多权重参数变为 。这是因为 L1 正则化倾向于将权重绝对值缩小到零,使得模型只保留对结果最重要的特征,而将其他不相关的特征权重设为零,从而实现 特征选择 的功能。

  2. 防止过拟合:通过限制权重的绝对值,L1 正则化减少了模型的复杂度,使其不容易过拟合训练数据。相比于 L2 正则化,L1 正则化更倾向于将某些权重完全移除,而不是减小它们的值。

  3. 简化模型:由于 L1 正则化会将一些权重变为零,因此模型最终会变得更加简单,仅依赖于少数重要特征。这对于高维度数据特别有用,尤其是在特征数量远多于样本数量的情况下。

  4. 特征选择:因为 L1 正则化会将部分权重置零,因此它天然具有特征选择的能力,有助于自动筛选出对模型预测最重要的特征。

l1_lambda = 0.001
# 计算 L1 正则化项并将其加入到总损失中
l1_norm = sum(p.abs().sum() for p in model.parameters())
loss = loss + l1_lambda * l1_norm

Dropout

Dropout 是一种在训练过程中随机丢弃部分神经元的技术。它通过减少神经元之间的依赖来防止模型过于复杂,从而避免过拟合。

import torch
import torch.nn as nndef dropout():dropout=nn.Dropout(p=0.5)x=torch.randn(2,2)print(x)print("------------------")print(dropout(x))if __name__ == "__main__":dropout()"""
tensor([[-0.3970, -1.8862],[-0.5632,  0.0390]])
------------------
tensor([[-0.7940, -3.7724],[-1.1264,  0.0000]])
"""

Dropout过程:

  1. 按照指定的概率把部分神经元的值设置为0;

  2. 为了规避该操作带来的影响,需对非 0 的元素使用缩放因子1/(1-p)进行强化。

权重影响
 

简化模型

  • 减少网络层数和参数: 通过减少网络的层数、每层的神经元数量或减少卷积层的滤波器数量,可以降低模型的复杂度,减少过拟合的风险。

  • 使用更简单的模型: 对于复杂问题,使用更简单的模型或较小的网络架构可以减少参数数量,从而降低过拟合的可能性。

数据增强

通过对训练数据进行各种变换(如旋转、裁剪、翻转、缩放等),可以增加数据的多样性,提高模型的泛化能力。

from torchvision import transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(10),transforms.ToTensor()
])"""
transforms.Compose([...])
transforms.Compose接受一个转换操作列表,并按顺序应用这些转换。它允许你将多个转换操作组合在一起,形成一个统一的转换流程。转换操作列表
1. transforms.RandomHorizontalFlip()
功能:随机水平翻转图像。
参数:默认情况下,p=0.5,即有50%的概率会对图像进行水平翻转。
用途:用于数据增强,增加模型的泛化能力。
2. transforms.RandomVerticalFlip()
功能:随机垂直翻转图像。
参数:默认情况下,p=0.5,即有50%的概率会对图像进行垂直翻转。
用途:同样用于数据增强,使模型能够更好地处理不同角度的图像。
3. transforms.RandomRotation(10)
功能:随机旋转图像。
参数:degrees,表示旋转的角度范围。这里的10表示图像将以10度为范围进行随机旋转(-10到10度之间)。
用途:增强数据,使模型能够更好地处理不同角度下的图像。
4. transforms.ToTensor()
功能:将PIL Image或numpy数组转换为PyTorch的Tensor,并且归一化到[0.0, 1.0]。
用途:将图像数据转换为适合在网络中使用的格式,同时进行归一化处理。
"""

早停

早停是一种在训练过程中监控模型在验证集上的表现,并在验证误差不再改善时停止训练的技术。这样可避免训练过度,防止模型过拟合。

模型集成

通过将多个不同模型的预测结果进行集成,可以减少单个模型过拟合的风险。常见的集成方法包括投票法、平均法和堆叠法。

 交叉验证

使用交叉验证技术可以帮助评估模型的泛化能力,并调整模型超参数,以防止模型在训练数据上过拟合。

这些方法可以单独使用,也可以结合使用,以有效地防止参数过大和过拟合。根据具体问题和数据集的特点,选择合适的策略来优化模型的性能。

批量标准化

批量标准化(Batch Normalization, BN)是一种广泛使用的神经网络正则化技术,核心思想是对每一层的输入进行标准化,然后进行缩放和平移,旨在加速训练、提高模型的稳定性和泛化能力。

实现过程

批量标准化的基本思路是在每一层的输入上执行标准化操作,并学习两个可训练的参数:缩放因子 \gamma和偏移量\beta

1.计算均值和方差

均值

\mu_B = \frac{1}{m} \sum_{i=1}^m x_i

方差

\sigma_B^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_B)^2

2.标准化

使用计算得到的均值和方差对数据进行标准化,使得每个特征的均值为0,方差为1。

标准化后的值

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}

3.缩放和平移

标准化后的数据通常会通过可训练的参数进行缩放和平移,以恢复模型的表达能力。

  • 缩放(Gamma)
    y_i = \gamma \hat{x}_i
     

  • 平移(Beta)
    y_i = \gamma \hat{x}_i + \beta

训练和推理阶段

  • 训练阶段: 在训练过程中,均值和方差是基于当前批次的数据计算得到的。

  • 推理阶段: 在推理阶段,批量标准化使用的是训练过程中计算得到的全局均值和方差,而不是当前批次的数据。这些全局均值和方差通常会被保存在模型中,用于推理时的标准化过程。

作用

提高神经网络的训练稳定性、加速训练过程并减少过拟合

可以从一下几个方面来改善:

1 缓解梯度问题

标准化处理可以防止激活值过大或过小,避免了激活函数(如 Sigmoid 或 Tanh)饱和的问题,从而缓解梯度消失或爆炸的问题。

2 加速训练

由于 BN 使得每层的输入数据分布更为稳定,因此模型可以使用更高的学习率进行训练。这可以加快收敛速度,并减少训练所需的时间。

3 减少过拟合

  • 类似于正则化:虽然 BN 不是一种传统的正则化方法,但它通过对每个批次的数据进行标准化,可以起到一定的正则化作用。它通过在训练过程中引入了噪声(由于批量均值和方差的估计不完全准确),这有助于提高模型的泛化能力。

  • 避免对单一数据点的过度拟合:BN 强制模型在每个批次上进行标准化处理,减少了模型对单个训练样本的依赖。这有助于模型更好地学习到数据的整体特征,而不是对特定样本的噪声进行过度拟合。

import torch 
import torch.nn as nndef test():x=torch.randn(2,3,4,4)print(x)bn=nn.BatchNorm2d(3)print(bn)if __name__=='__main__':test()"""
tensor([[[[-0.1629, -0.3630,  0.6086,  1.2669],[-0.7454,  0.1635, -0.1000,  0.9490],[ 0.2711,  1.9755, -0.6669, -0.3346],[-0.2467,  0.9544, -0.3537, -0.8904]],[[ 0.9441, -0.7221, -0.0377,  0.3374],[-2.2795, -1.1555,  0.9555,  0.4566],[-0.1251,  0.5129, -1.6877, -0.3519],[ 0.5455,  1.1250,  0.6385, -0.1447]],[[ 0.8211,  0.2494, -0.4131,  1.2432],[-0.6434,  1.1120,  1.1102,  0.8328],[ 0.0868,  0.2222,  0.1554, -0.7188],[ 0.8627, -0.7993, -0.8812,  0.9972]]],[[[-1.1189, -0.9412, -0.9145, -0.0048],[ 0.6170,  1.2101,  0.1813, -0.5363],[ 0.9798,  0.4064,  0.5711,  0.2156],[ 1.6940,  0.4776,  0.1171, -0.1421]],[[ 0.6171, -1.2645, -0.1189, -0.3172],[-0.5279,  0.3126,  1.5111,  0.8772],[-1.2101, -1.4024,  1.5457, -1.0882],[-1.4969, -0.3039, -0.6469, -0.3612]],[[-0.8796,  0.6566,  1.0026,  0.2472],[ 0.6985, -0.4325,  0.5768,  1.2399],[-0.8927, -0.3637, -0.5471, -1.9263],[ 0.9424,  1.6031,  1.5086,  0.2109]]]])
BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
"""


http://www.mrgr.cn/news/28809.html

相关文章:

  • 嵌入式课程day14-C语言指针进阶
  • 代码修改材质参数
  • blenderFds代码解读
  • Openshift 如何更新访问控制机
  • 力扣每日一题 3261. 统计满足 K 约束的子字符串数量 II
  • 从0开始学习机器学习--Day24--核函数
  • docker- No space left on device
  • 开源模型应用落地-qwen模型小试-调用Qwen2-VL-7B-Instruct-更清晰地看世界(一)
  • 紧急预警!台风贝碧嘉正面袭击上海浦东,风雨交加影响全城
  • 自然语言处理实战项目
  • 文件标识符fd
  • 【看这里】记录一下,如何在springboot中使用EasyExcel并行导出多个Excel文件并压缩zip后下载
  • Java 性能调优:优化 GC 线程设置
  • 【C++前后缀分解】1653. 使字符串平衡的最少删除次数|1793
  • DFS:二叉树中的深搜
  • Qt_输入类控件
  • 破损shp文件修复
  • 代码随想录算法训练营第57天|卡码网 53. 寻宝 prim算法精讲和kruskal算法精讲
  • 中位数贪心+分组,CF 433C - Ryouko‘s Memory Note
  • C++基于select和epoll的TCP服务器
  • 问题——IMX6UL的uboot无法ping主机或Ubuntu
  • 基于形状记忆聚合物的折纸超结构
  • 速通LLaMA2:《Llama 2: Open Foundation and Fine-Tuned Chat Models》全文解读
  • 【Elasticsearch系列九】控制台实战
  • 视频工具EasyDarwin将本地视频生成RTSP给WVP拉流列表
  • 螺丝、螺母、垫片等紧固件常用类型详细介绍