当前位置: 首页 > news >正文

神经网络入门—自定义神经网络续集

修改网络

神经网络入门—自定义网络-CSDN博客

修改数据集,y=x^2

# 生成一些示例数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[1.0], [4.0], [9.0], [16.0]], dtype=torch.float32)

将预测代码改为,可以接收用户输入并输出

# 加载模型
loaded_model = Net()
loaded_model.load_state_dict(torch.load('model.pth'))
loaded_model.eval()  # 将模型设置为评估模式
while True:# 输入新数据进行预测num=float(input())new_input = torch.tensor([[num]], dtype=torch.float32)with torch.no_grad():prediction = loaded_model(new_input)print(f"输入 {new_input.item()} 的预测结果: {prediction.item()}")

结果

分析

训练数据x为[1.0,2.0,3.0,4.0]

x为3.0和3.5时,测试数据与训练数据较为接近,模型能较为准确预测结果

x为5.0和10.0时,测试数据与训练数据有一定差别,模型预测结果比较不准确

x为-1时,模型预测为负数,实际应为正数,因为我们的训练集没有负数,所以模型没有学到这点

重新设计网络

增加-100-100数据集

# 生成 -100 到 100 范围内的 x
x_train = torch.arange(-100, 101, dtype=torch.float32).unsqueeze(1)
# 计算对应的 y,假设 y 是 x 的平方
y_train = x_train ** 2

Loss收敛慢,网络不能拟合实际函数

即时增加到3000次迭代仍然不能解决问题/(ㄒoㄒ)/~~

问题:

  1. 模型结构过于简单:当前模型仅包含两个全连接层,对于拟合 \(y = x^2\) 这样的非线性函数,可能表达能力不够。可以增加网络的深度和宽度,例如添加更多的隐藏层。
  2. 学习率不合适:学习率太大可能会使训练过程不稳定,太小则会导致收敛速度过慢。可以尝试使用自适应学习率的优化器,如 Adam。
  3. 训练轮数不足:可以适当增加训练轮数,让模型有更多的机会学习数据的特征。

增加网络层数

class Net(nn.Module):def __init__(self):super().__init__()# 增加网络的宽度和深度self.fc1 = nn.Linear(1, 20)self.fc2 = nn.Linear(20, 20)self.fc3 = nn.Linear(20, 20)self.fc4 = nn.Linear(20, 20)self.fc5 = nn.Linear(20, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)x = F.relu(x)x = self.fc3(x)x = F.relu(x)x = self.fc4(x)x = F.relu(x)x = self.fc5(x)return x

增加神经元个数

class Net(nn.Module):def __init__(self):super().__init__()# 增加网络的宽度和深度self.fc1 = nn.Linear(1, 200)self.fc2 = nn.Linear(200, 200)self.fc3 = nn.Linear(200, 200)self.fc4 = nn.Linear(200, 200)self.fc5 = nn.Linear(200, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)x = F.relu(x)x = self.fc3(x)x = F.relu(x)x = self.fc4(x)x = F.relu(x)x = self.fc5(x)return x

Loss波动,疑似出现过拟合


http://www.mrgr.cn/news/97764.html

相关文章:

  • CSRF漏洞技术解析与实战防御指南
  • 【WRF理论第十七期】单向/双向嵌套机制(含namelist.input详细介绍)
  • SAP ABAP 多线程处理/并行处理的四种方式
  • Quill富文本编辑器支持自定义字体(包括新旧两个版本,支持Windings 2字体)
  • 柑橘病虫害图像分类数据集OrangeFruitDaatset-8600
  • vue3中watch的使用示例
  • NO.84十六届蓝桥杯备战|动态规划-路径类DP|矩阵的最小路径和|迷雾森林|过河卒|方格取数(C++)
  • Stable Diffusion + Contronet,调参实现LPIPS最优(带生成效果+指标对比)——项目学习记录
  • 网络协议学习
  • macos下 ragflow二次开发环境搭建
  • ABAP小白开发操作手册+(十)验证和替代——下
  • js异步机制
  • OSPF基础入门篇②:OSPF邻居建立篇-网络设备的“社交礼仪“
  • 程序代码篇---时间复杂度空间复杂度
  • 如何在Dify中安装运行pandas、numpy库(离线、在线均支持,可提供远程指导)
  • OminiAdapt:学习跨任务不变性,实现稳健且环境-觉察的机器人操作
  • MCP协议介绍
  • Spring Security 的核心配置项详解,涵盖认证、授权、过滤器链、HTTP安全设置等关键配置,结合 Spring Boot 3.x 版本最佳实践
  • ruby超高级语法
  • DDoS防御与流量优化