一、源码分析
1.1 函数入口
void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages) {CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]const int d = Q.size(3); // B, H, N, dif (stages > 1) {switch (d){case 32:launch_flash_attn_mma_stages_split_q_shared_kv<32, 2>(Q, K, V, O);