当前位置: 首页 > news >正文

Pytorch xpu环境配置 Pytorch使用Intel集成显卡

1、硬件集显要为Intel ARC并安装正确驱动

2、安装Intel oneAPI Base Toolkit (https://www.intel.cn/content/www/cn/zh/developer/tools/oneapi/base-toolkit-download.html)安装后大约20G左右,注意安装路径

3、安装Visual Studio Build Tools (Microsoft C++ 生成工具 - Visual Studio)

安装时所有选项默认就行,安装如下组件就行

4、安装xpu版Pytorch 安装后大约6G左右

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu 
# 使用--target=d:\python\lib修改安装路径

5、测试

每次打开CMD窗口都要执行一次setvars.bat文件(oneAPI安装路径\oneAPI\setvars.bat)然后再执行python文件,注意只能在CMD窗口中执行,不能使用PowerShell

import torchprint(torch.xpu.is_available())

一个简单的模型训练例子 

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 使用cpu时删除所有.to(xpu)和.to(cpu)plt.rcParams['font.family'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
xpu = torch.device('xpu') # 使用CPU时可以删除这句# 1. 定义一个简单的神经网络模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(2, 2)  # 输入层到隐藏层self.fc2 = nn.Linear(2, 1)  # 隐藏层到输出层def forward(self, x):x = torch.relu(self.fc1(x))  # ReLU 激活函数x = self.fc2(x)return x# 2. 创建模型实例
model = SimpleNN()
model.to(xpu) # 使用CPU时可以删除这句# 3. 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器# 4. 假设我们有训练数据 X 和 Y
X = torch.randn(10, 2, requires_grad=True).to(xpu)  # 10 个样本,2 个特征
Y = torch.randn(10, 1).to(xpu)  # 10 个目标值
print(X,Y)
# 5. 训练循环
losses = []
for epoch in range(500):  # 训练 500 轮model.train()  # 设置模型为训练模式optimizer.zero_grad()  # 清空之前的梯度output = model(X)  # 前向传播loss = criterion(output, Y) # 计算损失losses.append(loss.item())loss.backward()  # 反向传播optimizer.step()  # 更新参数# 可视化预测结果与实际目标值对比
y_pred_final = model(X).detach().to("cpu").numpy()  # 最终预测值
y_actual = Y.to("cpu").numpy()  # 实际值plt.figure(figsize=(8, 5))
plt.plot(range(1, 11), y_actual, 'o-', label='实际值', color='blue')
plt.plot(range(1, 11), y_pred_final, 'x--', label='预测值', color='red')
plt.xlabel('Sample Index')
plt.ylabel('Value')
plt.title('Actual vs Predicted Values')
plt.legend()
plt.grid()
plt.show()


http://www.mrgr.cn/news/93021.html

相关文章:

  • QT——文件IO
  • Arduino:UNO板的接口和应用
  • unity学习62,尝试做第一个小游戏项目:flappy bird
  • Spring MVC 返回数据
  • CentOS 7.9 安装 ClickHouse 文档
  • python学习第三天
  • 【Transformer优化】什么是稀疏注意力?
  • ubuntu离线安装nvidia-container-runtime
  • NUDT Paper LaTeX 模板使用
  • Solana 核心概念全解析:账户、交易、合约与租约,高流量区块链技术揭秘!
  • GitLab常用操作
  • 第二节:基于Winform框架的串口助手小项目---创建界面《C#编程》
  • HarmonyOS NEXT开发进阶(十一):应用层架构介绍
  • unity pico开发二:连接头盔,配置手柄按键事件
  • 【和春笋一起学C++】逻辑操作符和条件操作符
  • MySQL快速搭建主从复制
  • 【C++指南】一文总结C++类和对象【中】
  • Nginx1.19.2不适配OPENSSL3.0问题
  • NL2SQL-基于Dify+阿里通义千问大模型,实现自然语音自动生产SQL语句
  • 小白向:如何使用dify官方市场“ECharts图表生成”工具插件——dify入门案例