litgpt框架笔记
litgpt的fsdp执行原理
python __main__.py finetune_full meta-llama/Lflama-2-7b-hf --config /home/xingzhuang/llm/litgpt/config_hub/finetune/llama-2-7b/full.yaml
( __main__.py 在litgpt/litgpt目录下)
执行该命令大致流程:
-
先讲讲full.yaml 中的global_batch_size和micro_batch_size参数的含义
- global_batch_size表示optimizer做一次step的总batch数,global_batch_size会均分给所有GPU,不妨记为local_batch,当某个GPU完成了自己的local_batch后optimizer才能做step更新参数
- micro_batch_size,每个GPU会将自己的local_batch进一步拆分成micro_batch,拆分大小为micro_batch_size
-
大致执行流程,主要在litgpt/finetune/full.py文件的fit函数中
-
batch = next(train_iterator)每次拿到一个micro_batch做forward
-
is_accumulating表示本轮micro_batch forward完成后,该GPU是否完成了local_batch
-
is_accumulating参数会传给fabric.no_backward_sync判断本轮forward对应的backward是否需要同步其他GPU的local_batch的梯度,其实就是保证local_batch累加的梯度都是自身local_batch的梯度
- 若is_accumulating为True表示该GPU还未完成local_batch,所以不需要同步其他GPU的local_batch的梯度(具体来讲,就是当某个GPU拉取某个layer的全部权重并算出该layer的梯度后,并不将梯度scatter给其他的GPU)
- 若is_accumulating为False表示该GPU已完成local_batch,所以会同步其他GPU的local_batch梯度
-
当所有GPU都完成了自己的local_batch后,则会执行optimizer.step()做一次梯度优化
pytorch-lightning的fsdp+tp原理
python train.py
(train.py在pytorch-lightning/examples/fabric/tensor_parallel/train.py)
执行该命令大致执行流程如下:
litgpt适配fsdp+tp
1.把pytorch-lightning/examples/tensor_parallel文件夹下的parallelism.py和model.py复制到litgpt/litgpt/finetune/下
2. 把full.py中的strategy改为
strategy = ModelParallelStrategy(# User-defined function that applies the desired parallelizations specific to the model# (TP, FSDP2, activation checkpointing, ...)parallelize_fn=parallelize,# Define the size of the 2D parallelism# Set to "auto" to apply TP intra-node and DP inter-nodedata_parallel_size="auto", tensor_parallel_size="auto",)
3.在litgpt/litgpt/model.py下的class CausalSelfAttention的__init__.py方法中加上代码
self.n_heads = config.n_head
self.n_kv_heads = config.n_head
4.修改parallelism.py文件
5.修改litgpt/litgpt/utils.py的load_checkpoint方法
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:if isinstance(fabric.strategy, FSDPStrategy):fabric.load_raw(checkpoint_path, model, strict=strict)elif isinstance(fabric.strategy, ModelParallelStrategy):fabric.load_raw(checkpoint_path, model, strict=False)else:state_dict = lazy_load(checkpoint_path)state_dict = state_dict.get("model", state_dict)model.load_state_dict(state_dict, strict=strict)
6.修改litgpt/litgpt/model.py 下的class CausalSelfAttention:
把self.attn改成self.attn_w
7.有个包貌似有问题
/home/xingzhuang/software/anaconda3/envs/litgpt/lib/python3.9/site-packages/torch/distributed/tensor/parallel/api.py
临时解决方法:把/home/xingzhuang/software/anaconda3/envs/litgpt/lib/python3.9/site-packages/torch/distributed/tensor/parallel/style.py的_apply函数中
NotImplementedError改为print,不终止报错