Python 神经网络项目常用语法
- 一、 工具
- 1. 导入模块和包
- 2. 修改系统路径 (sys.path.append)
- 3. 命令行参数解析 (argparse 模块)
- 4. 字符串格式化
- 5. 获取设备
- 6. 加载数据
- 7.设置推理结果的子路径
- 8. main() 脚本入口点
- 二、类相关
- 1. 类的定义及初始化
- 2. 类的实例化及函数调用
- 三、神经网络常用类
- 1. 工具类
- 1.1 正弦位置编码类
- 1.2 上采样/下采样
- 1.3 标准化
- 1.4 归一化层
- 1.5 提取函数 extract
- 2. 构建网络块类
- 2.1 block()
- 2.2 残差连接
- 2.3 ResnetBlock
- 2.3 Attention
一、 工具
1. 导入模块和包
import os
import argparse
import sys as s
from accelerate import Accelerator
import
:用于导入模块和包,可以选择导入单个模块、多个模块。from ... import ...
:从特定模块中导入具体的类、函数或变量。impoert ... as ...
:可以为导入的模块指定别名,使代码简洁。
2. 修改系统路径 (sys.path.append)
# 返回当前脚本所在目录的父目录
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) # 返回当前脚本所在目录的上上级目录
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
这两行代码通过 os.path.dirname(__file__)
获取当前脚本所在的目录,然后使用 ..
依次返回上一级目录。
sys.path.append()
:修改模块搜索路径,可以动态添加额外的搜索路径,以便访问其他目录中的模块。os.path.join(a, b, ...)
:将多个路径部分连接成一个完整的路径,并确保使用正确的路径分隔符(在 UNIX 系统上是/
,在 Windows 上是\
)。- 例如,
os.path.join("/home/user", "project")
将返回/home/user/project
。
- 例如,
os.path.dirname(__file__)
和..
、..
是用来构建路径的。__file__
是一个内置变量,它指向当前脚本文件的路径(字符串格式)。例如,如果脚本文件的路径是/home/user/project/scripts/train.py
,则__file__
就是这个文件的完整路径。os.path.dirname(path)
函数返回给定路径 path 的父目录(即去掉路径中的文件名部分)。例如,如果__file__
是/home/user/project/scripts/train.py
,那么os.path.dirname(__file__)
将返回/home/user/project/scripts
,即文件所在的目录。..
在路径中代表上一级目录,是一个相对路径。如果当前路径是/home/user/project/scripts
,则os.path.join("/home/user/project/scripts", '..')
会返回/home/user/project
,即当前脚本所在目录的父目录。
3. 命令行参数解析 (argparse 模块)
import argparse# 创建 ArgumentParser 对象,description 参数会显示在帮助信息中
parser = argparse.ArgumentParser(description='Train EBM model')# 添加命令行参数,指定名称、默认值、类型、帮助信息
parser.add_argument('--model_type', default="states", type=str,help='choices: states | thetas')
parser.add_argument('--dataset', default='jellyfish', type=str, help='dataset to evaluate')
parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
parser.add_argument('--epochs', default=10, type=int, help='Number of epochs to train the model')# 解析命令行参数并将其存储在 FLAGS 对象中
FLAGS = parser.parse_args()# 使用条件判断 if-elseif-else 语句,根据 model_type 选择相应的操作
if FLAGS.model_type == "states":...
elif FLAGS.model_type == "thetas":...# 输出解析的参数值
print("Dataset:", FLAGS.dataset)
print("Batch size:", FLAGS.batch_size)
print("Epochs:", FLAGS.epochs)
在大多数深度学习训练脚本中,使用 argparse
模块来处理命令行参数。
parser = argparse.ArgumentParser()
:创建一个命令行参数解析器parser
。parser
是通过argparse.ArgumentParser()
创建的实例,它负责定义和解析命令行输入。parser.add_argument()
:添加命令行参数,指定名称、类型、默认值及帮助信息。default
:参数的默认值,如果命令行未提供该参数,则使用默认值。type
:定义参数的数据类型。help
:提供该参数的帮助描述。
parser.parse_args()
:从命令行中解析传入的参数,并将它们存储在 args 对象中。args 是一个命名空间对象,可以通过点语法访问各个参数。
运行命令行的不同示例:
- 不传递任何参数,使用默认值:
输出:# 命令行python script.py
Dataset: jellyfish Batch size: 4 Epochs: 10
- 传递自定义的命令行参数:
输出:# 命令行 python script.py --dataset "fish" --batch_size 8 --epochs 20
Dataset: fish Batch size: 8 Epochs: 20
- 查看帮助信息:
输出:python script.py --help
usage: script.py [-h] [--dataset DATASET] [--batch_size BATCH_SIZE] [--epochs EPOCHS] Train EBM model optional arguments:-h, --help show this help message and exit--dataset DATASET Dataset to evaluate--batch_size BATCH_SIZE Batch size for training--epochs EPOCHS Number of epochs to train the model
4. 字符串格式化
print("Saved at: ", results_path)
print("DATA_PATH: ", DATA_PATH)
print("number of parameters in model: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
print()
:用于输出信息。- 字符串连接:
print
语句中使用逗号分隔多个变量,可以连接字符串和变量。
5. 获取设备
def get_device():return torch.device("cuda:4" if torch.cuda.is_available() else "cpu")args.device = get_device()
在深度学习任务中,通常需要指定训练的设备,如果你的机器有支持的 GPU,cuda 可以加速模型训练。如果没有 GPU,则会使用 CPU。
get_device()
这个函数会检查当前是否有可用的 GPU。
torch.device("cuda:4")
中的cuda:4
表示选择第 5 个 GPU。如果你的机器没有 5 个以上的 GPU,可能会报错。你可以使用cuda
或cuda:0
(通常默认使用第一个 GPU)。torch.cuda.is_available()
用来检查当前是否有可用的 GPU,如果没有返回 False,则会回退到 CPU。
args.device = get_device()
调用 get_device()
函数,将返回的设备对象(cuda:4
或 cpu
)赋值给 args.device
。args.device
之后可以用来指定模型的设备,确保模型训练或推理时使用正确的硬件资源。
上述代码可修改为:
def get_device():if torch.cuda.is_available():# 如果有多个 GPU,选择第一个 GPU,避免 'cuda:4' 报错return torch.device("cuda:0") else:return torch.device("cpu")args.device = get_device()
6. 加载数据
def cycle(dl):while True:for data in dl:yield data
无限循环地返回数据加载器中的数据。
7.设置推理结果的子路径
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
args.inference_result_subpath = os.path.join(args.inference_result_path,current_time + "_coeff_ratio_w_{}_J{}_".format(args.coeff_ratio_w, args.coeff_ratio_J)
)
这段代码构建推理结果保存的子目录路径 inference_result_subpath
,这个路径将包含当前时间戳以及与 coeff_ratio_w
和 coeff_ratio_J
参数相关的信息。
- 使用
datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
获取当前的时间戳,格式为年-月-日_时-分-秒
。 - 使用
args.coeff_ratio_w
和args.coeff_ratio_J
格式化字符串,这两个参数的值会被嵌入路径中。这些参数可能是用户在命令行中设置的值,影响模型训练或推理过程中的某些超参数。 - 最终路径会类似于:
inference_result_path/2024-11-17_12-30-45_coeff_ratio_w_0.3_J0.5_
。
8. main() 脚本入口点
# inference.py
def main(args):diffusion = load_model(args)dataloader = load_data(args)inference(dataloader, diffusion, args)if __name__ == '__main__':parser = argparse.ArgumentParser(description='xxx model')...args = parser.parse_args()main(args) # # 调用 main 函数
main 函数通常是脚本的核心逻辑部分,处理程序的主要任务。在 main() 函数中,可以使用 args 对象中的参数值来决定程序的行为,比如设置模型的超参数、选择训练模式等。
这段代码展示了 main 函数的核心逻辑,主要包括以下几个步骤:
-
加载模型(
diffusion = load_model(args)
):这行代码通过调用 load_model 函数加载模型及相关组件。args 参数传递了外部配置,可能包含模型路径、超参数等信息。load_model 函数返回 diffusion 扩散模型对象。 -
加载数据(
dataloader = load_data(args)
):这行代码调用 load_data 函数加载数据。args 中包含有关数据的路径或其他配置信息。dataloader 通常是一个生成数据批次(batch)的迭代器,可以用于训练或推理过程。 -
推理(
inference(dataloader, diffusion, args)
)在这行代码中,inference 函数接受输入数据的迭代器、扩散模型和配置信息参数,inference 函数通常用于执行模型的推理过程,生成预测或结果。
if __name__ == '__main__'
这一行的作用是确保当这个脚本作为主程序执行时,main() 函数会被调用。main(args)
这行代码调用 main() 函数,并将 args 对象作为参数传递给它。
因此,如果这个脚本是作为独立的 Python 文件运行的,它会执行 main 函数。
当你运行命令 python inference.py
时:
- Python 解释器会执行整个
inference.py
文件。 - 脚本中的
if __name__ == '__main__':
块会被触发。 - 然后会调用 main(args) 函数,并传入通过 argparse 解析的命令行参数 args。
这意味着,main 函数将会被执行,在 main 中使用命令行参数 args 来配置模型、加载数据、进行推理等。
二、类相关
要想使用类,先使用 class 关键字去定义类,而如何使用类就在类方法中定义,然后再用实例化对象去调用类方法。
-
有个特殊参数每个类都有,它叫
self
,它代表类的当前实例,使得方法可以访问和修改实例的属性。 -
有个特殊的类方法每个类都会定义,它叫
__init__
类初始化方法或构建方法,它会在这个类创建实例化对象时自动被调用,并且传入实例化时的参数。通常在该方法中对实例属性进行初始化。
# 定义一个类
class Dog:def __init__(self, name, age): # 初始化方法self.name = name # 初始化属性self.age = agedef speak(self): # 类方法print(f"{self.name} says woof!")# 实例化一个对象
dog1 = Dog("Buddy", 3) # 创建 Dog 类的实例,__init__ 方法自动执行# 调用类方法
dog1.speak() # 输出: Buddy says woof!
Dog
是类,定义了__init__
方法来初始化 name 和 age 属性。speak()
是一个类方法,它输出 name 属性的内容。- 通过
dog1 = Dog("Buddy", 3)
创建了一个Dog
类的实例,__init__
方法自动被调用,初始化了 dog1 的属性。 dog1.speak()
实例化对象 dog1 调用类方法,使用类。
1. 类的定义及初始化
__init__
是 Python 中的类初始化方法,也叫构造方法。它用于在类的对象被创建时初始化对象的状态(即设置对象的属性)。
__init__
方法会在类的实例化时自动调用,并且在对象创建后执行。
class ClassName:def __init__(self, parameters):# 初始化属性或执行其他必要的操作self.attribute = value# 其他代码
__init__(self, parameters)
:__init__
方法接受至少一个参数,通常是 self,它表示类的实例对象。parameters 是该方法接受的其他参数,用于在初始化时传递值。self.attribute
:self 表示当前对象实例,可以通过它访问类的属性和方法。- value 是初始化时为属性赋的值,可以是常量、变量或通过其他逻辑生成的值。
示例 1:基础初始化方法
class Person:def __init__(self, name, age):# 初始化时将名字和年龄赋值给对象的属性self.name = nameself.age = agedef introduce(self):print(f"Hello, my name is {self.name} and I am {self.age} years old.")
在这个例子中,Person 类的 __init__
方法接受两个其他参数:name 和 age,并将它们赋值给实例对象的属性 self.name
和 self.age
。
# 创建对象时,初始化属性
person1 = Person("Alice", 30)
person2 = Person("Bob", 25)# 调用方法
person1.introduce() # 输出: Hello, my name is Alice and I am 30 years old.
person2.introduce() # 输出: Hello, my name is Bob and I am 25 years old.
示例 2:使用默认参数
class Car:def __init__(self, make, model, year=2020):# 初始化时,year 如果未传递将默认为 2020self.make = makeself.model = modelself.year = yeardef display_info(self):print(f"{self.year} {self.make} {self.model}")
在这个例子中,year 参数具有默认值 2020。如果在创建 Car 实例时未传递 year 参数,它会自动使用默认值 2020。
# 创建时传递所有参数
car1 = Car("Toyota", "Camry", 2021)# 创建时只传递 make 和 model,year 会使用默认值 2020
car2 = Car("Honda", "Civic")# 输出汽车信息
car1.display_info() # 输出: 2021 Toyota Camry
car2.display_info() # 输出: 2020 Honda Civic
2. 类的实例化及函数调用
# train_script.py# Unet3D_with_Conv3D、GaussianDiffusion、Trainer 是从模块中导入的类
from model.video_diffusion_pytorch.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from diffusion.diffusion_2d_jellyfish import GaussianDiffusion, Trainerif __name__ == "__main__":# 解析命令行参数并将其存储在 FLAGS 对象中FLAGS = parser.parse_args()# 创建 Unet3D_with_Conv3D 模型实例model = Unet3D_with_Conv3D(dim = 64, # 设置模型的基础维度大小out_dim = 1 if FLAGS.only_vis_pressure else 3, # 根据命令行参数 only_vis_pressure 决定输出维度dim_mults = (1, 2, 4), # 传递一个元组作为参数,用于指定每个网络层维度的倍数channels=5 if FLAGS.only_vis_pressure else 7 # 根据命令行参数 only_vis_pressure 决定通道数)# 创建 GaussianDiffusion 实例diffusion = GaussianDiffusion(model,image_size = 64,frames=FLAGS.frames,cond_steps=FLAGS.cond_steps,timesteps = 1000, # 设置扩散步骤数sampling_timesteps = 250, # 采样步骤数loss_type = 'l2', # 设置损失函数类型:L1 or L2objective = "pred_noise",device =device # 模型运行的设备(CPU/GPU))# 创建 Trainer 类的实例,该类用于管理模型的训练trainer = Trainer(diffusion,FLAGS.dataset,FLAGS.dataset_path,FLAGS.frames,FLAGS.traj_len,FLAGS.ts,FLAGS.log_path,train_batch_size = FLAGS.batch_size, # 训练的批次大小train_lr = 1e-3, # 学习率train_num_steps = 400000, # 总训练步数gradient_accumulate_every = 1, # 指定进行梯度累积的次数ema_decay = 0.995, # 用于模型参数的指数移动平均值的衰减因子save_and_sample_every = 4000, # 每 4000 步保存模型和进行采样results_path = results_path,amp = False, # 是否使用混合精度训练calculate_fid = False, # 训练过程中是否计算 fidis_testdata = FLAGS.is_testdata,only_vis_pressure = FLAGS.only_vis_pressure,model_type = FLAGS.model_type)trainer.train() # 调用 Trainer 类的 train 方法,启动模型的训练过程
这段代码展示了如何定义和使用深度学习模型的训练流程,包括模型定义、模型实例化、训练参数设置,以及如何通过面向对象编程实现模块化。if __name__ == "__main__"
使得这段代码在直接运行脚本时会执行训练逻辑,而在导入时不会执行,从而提高了代码的复用性和模块化水平。
-
if __name__ == "__main__"
:它是模块和脚本的运行入口。该语句下的代码仅在该脚本作为主程序运行时才会被执行。__name__
变量:每个 Python 模块都有一个内置属性__name__
,其值决定了模块是被导入还是直接运行。__main__
:当一个 Python 文件被直接运行时,__name__
的值会被设置为__main__
。- 导入时的行为:如果该模块被其他脚本导入,
__name__
的值是该模块的文件名(不带路径和.py
扩展名)。
-
model = ClassName(arguments)
:创建类的实例,通过类构造函数ClassName
初始化对象model
。 -
out_dim = 1 if FLAGS.only_vis_pressure else 3
:使用条件表达式(类似三元运算符)来设置输出维度,如果FLAGS.only_vis_pressure
为真,out_dim
为 1,否则为 3。
示例:
# main_script.py
if __name__ == "__main__":print("This will only run when main_script.py is executed directly.")
运行结果:
- 如果运行
python main_script.py
,将输出This will only run when main_script.py is executed directly.
- 如果
main_script.py
被其他脚本导入,如import main_script
,这行代码不会被执行。
三、神经网络常用类
1. 工具类
1.1 正弦位置编码类
常见于自然语言处理(如 Transformer)和其他序列建模任务中。正弦位置编码是一种通过正弦和余弦函数表示位置的方式,使得模型能够感知输入数据中元素的顺序。
class SinusoidalPosEmb(nn.Module):def __init__(self, dim):super().__init__()self.dim = dimdef forward(self, x):device = x.devicehalf_dim = self.dim // 2emb = math.log(10000) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, device=device) * -emb)emb = x[:, None] * emb[None, :]emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb
- dim 是一个超参数,表示生成的位置编码的维度(即输出编码的总维度)。通常,这个维度应当是偶数,因为在后续计算中使用了正弦和余弦函数。
- forward 方法定义了正弦位置编码的计算过程。
half_dim = self.dim // 2
:将 dim 除以 2,计算出编码的一半维度,这将用于后续的正弦和余弦计算。emb = math.log(10000) / (half_dim - 1)
:计算一个常数,它与位置编码的尺度有关。这里的 10000 是一个经验值,常用于位置编码的计算。torch.exp(torch.arange(half_dim, device=device) * -emb)
:生成一个递减的指数序列。
torch.arange(half_dim)
生成一个从 0 到half_dim-1
的整数序列。然后通过乘以-emb
,再取指数,得到一组缩放因子,这些因子将用于后续的正弦和余弦函数。x[:, None]
:通过切片操作将 x 的维度从 [batch_size] 扩展到 [batch_size, 1],将其广播到与 emb 相乘。这意味着每个位置 x 的值将与 emb 中的每个尺度因子相乘,生成一个位置的尺度序列。
emb[None, :]
:emb 被扩展到 [1, half_dim],然后与x[:, None]
进行广播相乘,生成一个包含每个位置的编码因子。emb.sin()
和emb.cos()
:对每个位置的编码因子,分别计算正弦和余弦值。
torch.cat((emb.sin(), emb.cos()), dim=-1)
:将计算得到的正弦和余弦值沿最后一个维度(dim=-1
)连接起来,形成最终的编码。这将使得每个位置的编码有两倍的维度(half_dim 为每种类型,正弦和余弦各一半)。return emb
:返回生成的位置编码张量 emb。
位置编码为输入序列的每个位置生成一个向量,使得模型可以感知不同元素的相对位置。位置编码通过正弦和余弦函数的组合来实现,这种设计能够让模型在不同的尺度上感知位置信息,且不依赖于具体的训练数据。
该方法的关键优势是它为每个位置生成了一个独特的编码,不同的位置信息通过正弦和余弦函数映射到不同的维度上,能够捕捉到位置之间的相对关系。
RandomOrLearnedSinusoidalPosEmb
:可以选择使用随机的或学习的正弦位置编码。
class RandomOrLearnedSinusoidalPosEmb(nn.Module):def __init__(self, dim, is_random=False):super().__init__()assert (dim % 2) == 0half_dim = dim // 2self.weights = nn.Parameter(torch.randn(half_dim), requires_grad=not is_random)def forward(self, x):x = rearrange(x, 'b -> b 1')freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
1.2 上采样/下采样
Upsample
:上采样模块,使用最近邻插值方法和卷积层进行特征图的上采样。
def Upsample(dim, dim_out = None):return nn.Sequential(# 上采样操作,将输入特征图的尺寸扩大 2 倍,使用最近邻插值进行上采样nn.Upsample(scale_factor = 2, mode = 'nearest'),# 卷积层,输入通道数为 dim,输出通道数为 dim_out(如果 dim_out 未定义,则默认为 dim),卷积核大小为 3x3nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1))
Downsample
:下采样模块,使用Rearrange
层对特征图进行降维。
def Downsample(dim, dim_out = None):return nn.Sequential(# 输入张量的形状是 (batch_size, c, h, w),通过 p1=2 和 p2=2,它将 h 和 w 都按 2 的倍数重排列,# 因此,每个空间位置的特征将从 c 维度与 h 和 w 的子块合并,形成 dim * 4 个特征通道。这就相当于一个下采样操作,将空间尺寸减小一半。Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),# 卷积层,用于将通道数从 dim * 4 转换为 dim_out,dim_out 默认为 dim。积核大小为 1x1,通常用于改变通道数而不改变空间尺寸。nn.Conv2d(dim * 4, default(dim_out, dim), 1))
这两个函数 Upsample
和 Downsample
定义了数据的上采样和下采样操作。它们对输入数据的空间分辨率和通道数进行处理,其工作方式如下:
- 上采样
Upsample(dim, dim_out)
:将空间分辨率放大 2 倍(从(H, W)
到(2H, 2W)
),并通过卷积调整通道数。
-
nn.Upsample(scale_factor=2, mode='nearest')
:- 使用最近邻插值法,将输入的空间尺寸(高度和宽度)扩大为原来的 2 倍。
- 例如,输入张量大小为
(B, C, H, W)
,则经过这一步后,大小变为(B, C, 2H, 2W)
。
-
nn.Conv2d(dim, dim_out, 3, padding=1)
:- 卷积层,使用 3 × 3 3 \times 3 3×3 的卷积核对放大的数据进行处理,调整通道数为
dim_out
。 padding=1
确保空间分辨率保持不变(仍为(2H, 2W)
)。- 如果
dim_out
未指定,默认设置为dim
,即输出的通道数与输入一致。
- 卷积层,使用 3 × 3 3 \times 3 3×3 的卷积核对放大的数据进行处理,调整通道数为
最终,Upsample
的效果是:
- 输入维度:
(B, C, H, W)
- 输出维度:
(B, dim_out, 2H, 2W)
- 下采样
Downsample(dim, dim_out)
:将空间分辨率缩小 2 倍(从(H, W)
到(H/2, W/2)
),并通过通道重排和卷积调整通道数。
-
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)
:- 通过重排张量,将空间维度分块到通道维度。
- 将每 2 × 2 2 \times 2 2×2 的块合并成新的通道。
- 输入的大小为
(B, C, H, W)
,输出的大小变为(B, 4C, H/2, W/2)
。
重排细节:
- 空间维度
(H, W)
被分成H/2
和W/2
,每个分块是 2 × 2 2 \times 2 2×2。 C
通道被扩展为4C
,因为 2 × 2 = 4 2 \times 2 = 4 2×2=4。
-
nn.Conv2d(dim * 4, dim_out, 1)
:- 使用 1 × 1 1 \times 1 1×1 的卷积核对重新排列的通道进行压缩,调整通道数为
dim_out
。 - 如果
dim_out
未指定,默认值是dim
。
- 使用 1 × 1 1 \times 1 1×1 的卷积核对重新排列的通道进行压缩,调整通道数为
最终,Downsample
的效果是:
- 输入维度:
(B, C, H, W)
- 输出维度:
(B, dim_out, H/2, W/2)
1.3 标准化
WeightStandardizedConv2d
:实现了加权标准化卷积层,使用加权标准化(weight standardization)来提高卷积层的训练效率。
class WeightStandardizedConv2d(nn.Conv2d):def forward(self, x):eps = 1e-5 if x.dtype == torch.float32 else 1e-3weight = self.weightmean = reduce(weight, 'o ... -> o 1 1 1', 'mean')var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased=False))normalized_weight = (weight - mean) * (var + eps).rsqrt()return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
1.4 归一化层
LayerNorm
:自定义的层归一化,标准化每个通道的特征图。
class LayerNorm(nn.Module):def __init__(self, dim):super().__init__()self.g = nn.Parameter(torch.ones(1, dim, 1, 1))def forward(self, x):eps = 1e-5 if x.dtype == torch.float32 else 1e-3var = torch.var(x, dim=1, unbiased=False, keepdim=True)mean = torch.mean(x, dim=1, keepdim=True)return (x - mean) * (var + eps).rsqrt() * self.g
PreNorm
:先进行归一化,再传递给后续函数处理。
class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = LayerNorm(dim)def forward(self, x):x = self.norm(x)return self.fn(x)
1.5 提取函数 extract
def extract(a, t, x_shape): b, *_ = t.shape # 从 t 中获取 batch 大小 b,通常是输入数据的第一个维度out = a.gather(-1, t) # 使用 PyTorch 的 gather 函数,从 a 中按时间步索引 t 提取对应的参数return out.reshape(b, *((1,) * (len(x_shape) - 1))) # 将提取出的参数调整为形状 [B, 1, ..., 1],使得与 x_shape 的维度一致,便于后续广播
这段代码的目的是从参数向量 a
中按时间步索引 t
提取对应的值,并调整其形状为 [b, 1, 1, ..., 1]
,以兼容目标张量 x_shape
的广播操作。
这种形式在扩散模型中非常常见,用于将时间相关的参数扩展到整个张量操作。
输入参数:
a
:表示一个预先计算好的时间步相关参数(如扩散模型中的alphas_cumprod
或snr
),通常是 [T] 的张量。t
:表示当前 batch 中每个样本的时间步索引([B] 的整数张量)。x_shape
:输入数据的形状,用于对结果进行广播。
难点理解 return out.reshape(b, *((1,) * (len(x_shape) - 1)))
:
-
b
- 表示批量大小,通常是输入数据的第一个维度。
- 例如,如果
x_shape = (8, 3, 32, 32)
(批大小为 8),则b = 8
。
-
len(x_shape) - 1
- 表示目标张量的维度数减去 1。
- 例如,对于
x_shape = (8, 3, 32, 32)
,len(x_shape) = 4
,因此len(x_shape) - 1 = 3
。
-
(1,) * (len(x_shape) - 1)
- 生成一个包含
len(x_shape) - 1
个值为1
的元组。 - 例如,如果
len(x_shape) - 1 = 3
,则结果是(1, 1, 1)
。
- 生成一个包含
-
out.reshape(b, *((1,) * (len(x_shape) - 1)))
- 将
out
的形状调整为[b, 1, 1, 1, ...]
,以便其与x_shape
的维度兼容。 b
决定第一个维度,其余维度填充1
,为后续广播做准备。
- 将
在 PyTorch 中,广播机制会自动扩展维度为 1
的张量,以匹配目标张量的维度。例如:
- 如果目标张量的形状是
[8, 3, 32, 32]
,一个形状为[8, 1, 1, 1]
的张量会被广播为[8, 3, 32, 32]
。
【举例说明】
输入数据:
a = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) # [2, 3]
t = torch.tensor([0, 2]) # 每个 batch 的时间步索引
x_shape = (2, 3, 32, 32) # 假设目标张量形状
执行过程:
-
a.gather(-1, t)
gather
根据t
从a
的最后一维提取值。a
的形状是[2, 3]
,它的最后一维大小为3
。
a = [[0.1, 0.2, 0.3], # 第一批次[0.4, 0.5, 0.6]] # 第二批次
- 对第一批次(索引
t[0] = 0
):从[0.1, 0.2, 0.3]
提取第0
个值,得到0.1
。 - 对第二批次(索引
t[1] = 2
):从[0.4, 0.5, 0.6]
提取第2
个值,得到0.6
。 - 结果:
out = [0.1, 0.6]
,形状为[2]
。
-
(1,) * (len(x_shape) - 1)
x_shape = (2, 3, 32, 32)
,因此len(x_shape) - 1 = 3
。- 结果:
(1, 1, 1)
。
-
out.reshape(b, *((1,) * (len(x_shape) - 1)))
b = 2
,结果形状为[2, 1, 1, 1]
。
当 out
(形状为 [2, 1, 1, 1]
)与目标张量(例如 [2, 3, 32, 32]
)相乘时:out
会自动广播为 [2, 3, 32, 32]
。
2. 构建网络块类
2.1 block()
class Block(nn.Module):def __init__(self, dim, dim_out, groups = 8):super().__init__()self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)self.norm = nn.GroupNorm(groups, dim_out)self.act = nn.SiLU()def forward(self, x, scale_shift = None):x = self.proj(x)x = self.norm(x)if exists(scale_shift):scale, shift = scale_shiftx = x * (scale + 1) + shiftx = self.act(x)return x
这段代码定义了一个名为 Block
的类,它继承自 nn.Module
,通常用于构建神经网络模型中的模块。这个 Block
类包括一个卷积层、归一化层、激活函数,以及可选的 scale_shift
操作。
整体作用:
Block
类构建了一个标准的深度学习模块,常用于卷积神经网络中的基本单元。- 它先通过标准化卷积层提取特征,再进行分组归一化,最后使用
SiLU
激活函数进行非线性变换。scale_shift
提供了额外的灵活性,使得模块可以在某些应用中进行动态调整。
初始化方法__init__
方法用于初始化类的实例:
- 参数:
dim
: 输入通道的数量。dim_out
: 输出通道的数量。groups
: 归一化层的分组数,默认为 8。
super().__init__()
调用父类nn.Module
的构造方法。- 初始化的组件:
self.proj
:一个卷积层,使用自定义的WeightStandardizedConv2d
(这是一个标准化权重的卷积层),卷积核大小为 3,填充为 1。self.norm
:nn.GroupNorm
归一化层,使用groups
参数对输入通道进行分组归一化。self.act
:激活函数nn.SiLU()
,一种平滑的激活函数,也称为 Swish 激活函数。
forward
前向方法定义了数据如何通过网络进行传递:
-
参数:
x
:输入张量。scale_shift
:一个可选参数,包含(scale, shift)
,用于缩放和平移x
。
-
前向过程:
x = self.proj(x)
:将输入x
通过卷积层进行卷积操作。x = self.norm(x)
:将卷积后的结果进行分组归一化。- 检查
scale_shift
是否存在(使用exists(scale_shift)
)。如果存在,将scale_shift
拆分为scale
和shift
,其中scale
是缩放因子,shift
是偏移量。并应用以下变换:x = x * (scale + 1) + shift
,将x
进行缩放和偏移调整。 x = self.act(x)
:将结果通过激活函数SiLU
激活。- 返回结果
x
。
2.2 残差连接
Residual
类 :残差连接模块,接受一个函数作为输入, 返回该函数的输出与输入的和。
class Residual(nn.Module):def __init__(self, fn):super().__init__()self.fn = fndef forward(self, x, *args, **kwargs):return self.fn(x, *args, **kwargs) + x
2.3 ResnetBlock
class ResnetBlock(nn.Module):def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):super().__init__()self.mlp = nn.Sequential(nn.SiLU(),nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else Noneself.block1 = Block(dim, dim_out, groups = groups)self.block2 = Block(dim_out, dim_out, groups = groups)self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()def forward(self, x, time_emb = None):scale_shift = Noneif exists(self.mlp) and exists(time_emb):time_emb = self.mlp(time_emb)time_emb = rearrange(time_emb, 'b c -> b c 1 1')scale_shift = time_emb.chunk(2, dim = 1)h = self.block1(x, scale_shift = scale_shift)h = self.block2(h)return h + self.res_conv(x)
这段代码定义了一个 ResnetBlock
类,它是用于深度学习模型中的残差块,继承自 nn.Module
。ResnetBlock
包括两个 Block
层和一个残差连接,并可以在需要时接受时间嵌入(time_emb
)用于调节网络行为。
__init__
方法 初始化组件:
self.mlp
: 一个 MLP 层,用于将时间嵌入投影到dim_out * 2
维度。如果time_emb_dim
存在,则创建此 MLP;否则为None
。self.block1
和self.block2
: 两个Block
实例,分别用于特征提取。self.res_conv
: 一个 1x1 卷积层(或恒等映射),用于调整输入x
与输出h
在维度上的匹配。如果dim
和dim_out
不相同,则使用卷积层进行调整。
具体来说,
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
这两行代码实例化了两个 Block 层,并在 ResnetBlock 中分别赋值给self.block1
和self.block2
。每个 Block 对象的初始化会执行 Block 类的__init__
方法中的内容,
- 当创建
self.block1
时,Block 的__init__
方法执行以下步骤:
self.proj
初始化为一个 WeightStandardizedConv2d 卷积层,它的输入通道数为 dim,输出通道数为 dim_out,卷积核大小为 3,并且带有 1 个像素的填充。
self.norm
初始化为 GroupNorm,用于将输出通道 dim_out 进行分组归一化。
self.act
初始化为 SiLU 激活函数。- 创建
self.block2
时,执行了 Block 的__init__
方法,但与self.block1
的主要区别是:
这次 proj 卷积层的输入通道和输出通道都是 dim_out,使得输出通道数保持不变。
forward
方法:
- 接收输入
x
和可选的时间嵌入time_emb
。 - 如果
self.mlp
存在且time_emb
存在,则将time_emb
通过MLP
,并通过rearrange
重塑为适合卷积操作的形状。然后将其切分为scale
和shift
(用于scale_shift
操作)。 - 通过
block1
进行前向传播,并应用scale_shift
。 - 经过
block2
进一步处理。 - 最后,返回
h + self.res_conv(x)
,实现残差连接。
2.3 Attention
class Attention(nn.Module):def __init__(self, dim, heads = 4, dim_head = 32):super().__init__()self.scale = dim_head ** -0.5self.heads = headshidden_dim = dim_head * headsself.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)self.to_out = nn.Conv2d(hidden_dim, dim, 1)def forward(self, x):b, c, h, w = x.shapeqkv = self.to_qkv(x).chunk(3, dim = 1)q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)q = q * self.scalesim = einsum('b h d i, b h d j -> b h i j', q, k)attn = sim.softmax(dim = -1)out = einsum('b h i j, b h d j -> b h i d', attn, v)out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)return self.to_out(out)
这段代码定义了一个用于图像数据的自注意力机制模块,Attention
,它继承自 nn.Module
并用于深度学习模型中。
-
__init__
方法:- 接收
dim
(输入通道数)、heads
(多头注意力头的数量)、dim_head
(每个注意力头的维度)作为参数。 - 计算
hidden_dim
为dim_head * heads
,这是多头注意力中每个Q
、K
、V
的总维度。 self.scale
用于缩放Q
向量,以稳定训练。self.to_qkv
是一个1x1
卷积层,将输入的特征图变换为Q
(查询)、K
(键)和V
(值)向量,输出通道数是hidden_dim * 3
。self.to_out
是另一个1x1
卷积层,用于将注意力机制的输出映射回输入的维度dim
。
- 接收
-
forward
方法:- 输入
x
是一个四维张量,形状为(batch_size, channels, height, width)
。 - 使用
self.to_qkv(x)
生成Q
、K
和V
,并通过chunk(3, dim=1)
将它们分开。 - 使用
map()
和rearrange()
将Q
、K
、V
重排,使它们适应多头注意力的形状(batch_size, heads, dim_head, tokens)
,其中tokens = height * width
。 Q
向量乘以self.scale
进行缩放。- 计算相似度矩阵
sim
,通过einsum('b h d i, b h d j -> b h i j', q, k)
实现,表示Q
和K
之间的点积。 - 使用
softmax
计算注意力权重attn
。 - 计算注意力加权后的输出
out
,通过einsum('b h i j, b h d j -> b h i d', attn, v)
将注意力矩阵应用于V
。 - 重排输出形状回
(batch_size, channels, height, width)
。 - 最后,使用
self.to_out(out)
将输出映射回原输入通道数。
- 输入
该模块用于提取图像特征的自注意力机制,帮助模型在处理复杂输入数据时,捕获长距离依赖和上下文信息。