Cholesky Decomposition

11k 词

Cholesky 分解在干什么

Cholesky 分解是将一个对称、正定的矩阵 AA,分解为一个下三角矩阵 LL 及其转置 LTL^T 乘积:

A=LLTA = LL^T

在 GPU 上为了实现并行加速,我们不能逐元素计算,而是采用分块迭代的策略。每一轮迭代中,我们只处理矩阵的一个“步长”。

假设当前迭代处理到矩阵的某个位置,我们将矩阵 AA 划分为如下三类区域:

  • A11A_{11}:当前步长下的对角小方块,大小通常为 nb×nbnb \times nb(如 128×128128 \times 128)。
  • A21A_{21}:位于 A11A_{11} 下方的瘦长矩阵,宽度为 nbnb,高度为剩余矩阵的高度。
  • A22A_{22}:右下角剩余的大方阵。

对应的下三角矩阵 LL 也对应分成四个部分,最终得到:

[A11A21TA21A22]=[L110L21L22][L11TL21T0L22T]\begin{bmatrix} A_{11} & A_{21}^T \\ A_{21} & A_{22} \end{bmatrix} = \begin{bmatrix} L_{11} & 0 \\ L_{21} & L_{22} \end{bmatrix} \begin{bmatrix} L_{11}^T & L_{21}^T \\ 0 & L_{22}^T \end{bmatrix}

根据块的计算过程,我们可以得到三个等式:

A11=L11L11T+00=L11L11TA21=L21L11T+L220=L21L11TA22=L21L21T+L22L22TA_{11} = L_{11} L_{11}^T + 0 \cdot 0 = L_{11} L_{11}^T\\ A_{21} = L_{21} L_{11}^T + L_{22} \cdot 0 = L_{21} L_{11}^T\\ A_{22} = L_{21} L_{21}^T + L_{22} L_{22}^T

为了求出 LL,分成三个步骤,反复执行:

  1. POTRF (对角块分解):A11CholeskyL11A_{11} \xrightarrow{Cholesky} L_{11}.
    • 对左上角的小方块 A11A_{11} 进行 Cholesky 分解,得到 L11L_{11}.
  2. TRSM(三角方程求解): L21=A21(L11T)1L_{21} = A_{21}(L_{11}^T)^{-1}.
    • 利用刚才算出的 L11L_{11},去更新它下方的长条矩阵 A21A_{21},算出 L21L_{21}.
  3. SYRK (对称秩-K 更新): A22=A22L21L21TA_{22}' = A_{22} - L_{21}L_{21}^T
    • 用算出来的 L21L_{21} 把右下角剩下的那一块大矩阵 A22A_{22} 给“削掉”一层。

完成这三步后,我们就确定了 LL 的一列“块”。下一轮迭代,我们将从更新后的 A22A_{22}' 的左上角开始,重复上述过程,直到整个矩阵被处理完。

分块迭代 Cholesky 分解

通过上节的认识,可以明确分块 Cholesky 分解并不是一次性完成的,而不是一个典型的迭代过程。

在每一轮迭代中,矩阵被划分为三个区域:

  • 已处理区域:矩阵的左上方,已经变成了最终的 LL 的一部分。
  • 当前工作区:包含当前对角块 A11A_{11} 和下方长条 A21A_{21}.
  • 待处理区域:右下角的 A22A_{22} 部分,它将在下一轮变成新的工作区。

每一轮完成 L11L_{11} L21L_{21} 的计算,并留下 L22L_{22} 在下轮计算。

假设矩阵总大小为 NN,分块大小为 nbnb,当前处理到第 ii 个块:

  1. 当前块的起始地址:
    • A11A_{11} 的起始地址位于 (i×nb,i×nb)(i\times nb, i\times nb),大小为 nb×nbnb \times nb.
    • A21A_{21} 的起始地址位于 ((i+1)×nb,i×nb)((i+1)\times nb, i\times nb),大小为(N(i+1)×nb)×nb(N-(i+1)\times nb)\times nb.
  2. POTRF:对 A11A_{11} 进行分解,这是一个小规模的 Cholesky 分解:A11L11A_{11}\rightarrow L_{11}.
  3. TRSM:利用 L11L_{11} 的三角特性,解出下方整列 L21L_{21}L21=A21(L11T)1L_{21} = A_{21} (L_{11}^T)^{-1}.
  4. SYRK:利用刚刚算出的 L21L_{21} 更新 A22A_{22}A22=A22L21L21TA_{22} = A_{22} - L_{21} L_{21}^T.

大概的逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
for (int i = 0; i < N; i += nb) {
// 1. 处理对角块 (POTRF)
int current_nb = min(nb, N - i);
chol_factorize(A(i, i), current_nb);

if (i + current_nb < N) {
// 2. 更新下方长条 (TRSM)
trsm_solve(A(i, i), A(i + current_nb, i), current_nb, N - i - current_nb);

// 3. 更新右下角剩余矩阵 (SYRK)
syrk_update(A(i + current_nb, i), A(i + current_nb, i + current_nb), current_nb, N - i - current_nb);
}
}

迭代算法的核心就是:

  1. 定义步长:设定一个固定的分块大小 nbnb(Tile Size)。
  2. 局部处理:在当前的剩余矩阵(也就是你说的 A22A_{22})中,锁定左上角那块 nb×nbnb \times nb 的区域作为新的 A11A_{11}
  3. 计算与更新:分解这个 A11A_{11}。利用结果去更新它下方的“长条”和右下角的“大饼”。
  4. 窗口滑动:完成这一轮后,算法的“视角”向右下方移动 nbnb 个坐标,把上一轮的 A22A_{22} 当作新舞台,重复步骤 2。

Cholesky 在 GPU 上的性能瓶颈

迭代算法强行将矩阵切割为宽度固定为 nbnb 的碎片。这种设计在 CPU 时代是为了适配缓存(Cache),但在 GPU 时代却成了桎梏。由于 nbnb 通常很小(如 64 或 128),它强制后续所有算子(TRSM、SYRK)都在极其“瘦长”的空间内执行。

在这种架构下,TRSM 被迫处理高度极高但宽度极窄的“纸片矩阵”,SYRK 则被困在“窄 K 维”的计算中。这导致 Tensor Cores 这种为了大规模方阵乘法而生的硬件,由于任务宽度不足,根本无法形成有效的计算瓦片(Tiles),硬件利用率被强行锁死在极低的水平。

但在解决迭代算法之前,我们先介绍两个算子的传统计算过程。

TRSM 具体算法与瓶颈分析

在实际的计算中,我们不会真的像前文写的 L21=A21(L11T)1L_{21} = A_{21} (L_{11}^T)^{-1} 来算 L21L_{21}.

而是利用:L21L11=A21L_{21}L_{11} = A_{21} 这个等式来求方程解出 L21L_{21}.

为了便于表述,现在使用 XX 替代 L21L_{21} 作为待解矩阵,使用 UU 代替 L11TL_{11}^T 作为已知的上三角矩阵,使用 AA 作为 A21A_{21} 作为结果;并设 XX 形状为 M×NM\times N,则 UU 形状为 NNAAXX 形状相同。

我们观察 AA 的第一列数值是如何计算得到的:

A11=X1,:×U:,1A12=X2,:×U:,1A1m=Xm,:×U:,1A_{11} = X_{1,:} \times U_{:,1} \\ A_{12} = X_{2,:} \times U_{:,1} \\ \dots \\ A_{1m} = X_{m,:} \times U_{:,1}

UU 是上三角矩阵,第一列除了 U11U_{11} 其余均为0,和 XX 任意行相乘得到的实际上是 XX 该行第一个元素与 U11U_{11} 的乘积,我们可以发现:

A11=X11×U11A12=X21×U11A1m=Xm1×U11A_{11} = X_{11}\times U_{11}\\ A_{12} = X_{21}\times U_{11}\\ \dots \\ A_{1m} = X_{m1}\times U_{11}

如果将 XX 以多个列向量分块得到 [x1,x2,,xn][x_1, x_2, \dots, x_n],将 AA 以多个列向量分块得到 [a1,a2,,an][a_1, a_2, \dots ,a_n],那么 AA 的第一列的结果实际上可以表示为:

x1×u11=a1x_1\times u_{11} = a_1

所以我们可以得到:

x1=b1/u11x_1 = b_1/u_{11}

第二列:

A21=X1,:×U:,2A22=X2,:×U:,2A2m=Xm,:×U:,2A_{21} = X_{1,:} \times U_{:,2} \\ A_{22} = X_{2,:} \times U_{:,2} \\ \dots \\ A_{2m} = X_{m,:} \times U_{:,2}

由于 UU 第二列有效元素有两个:u12u_{12} u22u_{22},那么第二列 AA 的计算等式发生了改变:

A21=X11×U12+X12×U22A22=X21×U12+X22×U22A2m=Xm1×U12+Xm2×U22A_{21} = X_{11}\times U_{12}+X_{12}\times U_{22}\\ A_{22} = X_{21}\times U_{12}+X_{22}\times U_{22}\\ \dots \\ A_{2m} = X_{m1}\times U_{12}+X_{m2}\times U_{22}

可以得到:

x1u12+x2u22=a2x_1u_{12} + x_2u_{22} = a_2

解出:$$x_2 = (a_2 - x_1u_{12})/u_{22}$$
这需要依赖 x1x_1 的结果。

第3列:

x1u13+x2u23+x3u33=a3x3=(a3x1u13x2u23)/u33x_1u_{13} + x_2u_{23} + x_3u_{33} = a_3\\ x_3 = (a_3-x_1u_{13}-x_2u_{23})/u_{33}

这需要依赖 x1,x2x_1,x_2 的结果。
jj 列:

aj=x1u1j+x2u2j+...+xjujja_j = x_1 u_{1j} + x_2 u_{2j} + ... + x_j u_{jj}

这需要依赖 x1,x2,xj1x_1,x_2,\dots x_{j-1} 的结果。
这种计算导致计算后序列的结果必须依赖前列。

xjujjx_j u_{jj} 孤立出来:

xjujj=aj(x1u1j+x2u2j++xj1uj1,j)x_j u_{jj} = a_j - (x_1 u_{1j} + x_2 u_{2j} + \dots + x_{j-1} u_{j-1,j})

即 $$x_j u_{jj} = a_j - \sum_{k=1}^{j-1} x_k u_{kj}$$

在实现中,为了节省内存,我们通常采用原地计算的方式:直接在 aja_j 原有的数值上减去之前算好的 xkx_kukju_{kj} 的贡献。

即:

  1. 算出 x1x_1 \rightarrow 立即更新 a2,a3,,ana_2, a_3, \dots, a_n.
  2. 算出 x2x_2 \rightarrow 立即更新 a3,a4,,ana_3, a_4, \dots, a_n.

虽然这看起来很“并行”(因为右侧所有列可以同时减),但仍然会有不少的性能损耗:

  • 无法填充 Tensor Cores:Tensor Cores 的硬件设计是为 16×1616 \times 16 (或更高) 的矩阵块乘法设计的。
    • TRSM 在列与列之间是串行的。
    • 如果你一列一列算,每次送入 Tensor Core 的数据可能只是 M×1M \times 1,硬件利用率低。
  • 瘦矩阵问题:在计算 x2,x3x_2, x_3 \dots 时,虽然 x1u1jx_1 u_{1j} 看起来像向量乘法,但因为它太“瘦”了,无法在 GPU 上形成有效的线程块布局,导致线程束(Warp)内的大量分支分叉或闲置。
  • 累加项的串行开销:随着 jj 的增加,减法项(累加求和)越来越多。
    • 这种累加在标准实现中往往涉及大量的中间变量存取。如果没有将其转化为 GEMM (矩阵乘法),你就无法利用 GPU 专门为累加设计的硬件优化。

总的来说,这是一种 算完一列 更新一堆 的算法。

SYRK 具体算法与瓶颈分析

在完成 TRSM 阶段得到下三角块 L21L_{21} 后,Cholesky 分解进入了计算量最大的阶段:对称秩-K 更新(SYRK)。其数学表达式为:

A22=A22L21L21TA_{22} = A_{22} - L_{21}L_{21}^T

其中,A22A_{22} 是待更新的右下角对称矩阵,L21L_{21} 是在 TRSM 阶段求得的瘦长矩阵。

为了方便说明,假设 A22A_{22} 被划分为 2×22 \times 2 个子块(每个块的大小为 nb×nbnb \times nb),整体结构如下:

A22=[C11C12C21C22],L21=[B1B2]A_{22} = \begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}, \quad L_{21} = \begin{bmatrix} B_{1} \\ B_{2} \end{bmatrix}

对角块 C11C_{11} 的更新(严格对称更新)计算公式:$$C_{11} = C_{11} - B_1 B_1^T$$

计算逻辑:由于 C11C_{11} 位于主对角线上,且 Cholesky 只存储下三角,算法必须逐元素判断:

  • 对于 C11C_{11} 中的任意元素 cpqc_{pq},只有当 pqp \ge q 时才执行乘加运算。
  • 当计算块跨越对角线时,线程需要通过条件判断来决定是否执行计算或写回,这会导致线程束内出现分支分叉,从而降低执行效率;同时也破坏了 Tensor Core 对规则矩阵乘法的高效调度。

非对角块 C21C_{21} 的更新(标准矩形更新)计算公式:$$C_{21} = C_{21} - B_2 B_1^T$$

计算逻辑:这是 SYRK 中最接近标准 GEMM 的部分。

  • 因为 C21C_{21} 完全处于下三角区域,不涉及对角线切割,所以它可以被看作一个 nb×nbnb \times nb 的完整矩阵乘法。
  • 在这个块上,Tensor Core 可以跑出较高的效率。

对角块 C22C_{22} 的更新(严格对称更新)计算公式:$$C_{22} = C_{22} - B_2 B_2^T$$

  • 计算逻辑:同 C11C_{11},这又是一个需要严格处理边界的对称更新。

在传统的 GPU 库中,SYRK 的计算遵循以下逻辑:

  • 对称性约束:由于 A22A_{22} 是对称矩阵,为了减少浮点运算量(理论上减少一半),算法通常只计算并更新 A22A_{22} 的下三角部分。
  • 分块更新:将 A22A_{22} 划分为多个大小为 nb×nbnb \times nb 的子块。对于每一个子块 CijC_{ij}(满足 iji \ge j),计算 Cij=Cijk=1nbLi,k(LT)k,jC_{ij} = C_{ij} - \sum_{k=1}^{nb} L_{i,k} \cdot (L^T)_{k,j}
  • 对角块特殊处理:对于位于 A22A_{22} 主对角线上的子块,计算时需要严格遵守下三角存储格式,仅更新元素 apqa_{pq} 满足 pqp \ge q 的部分。

传统算法利用对称性,通过 SYRK 只更新矩阵的下三角部分,理论上减少了一半的运算量,但这种“减少计算量”的优化并不等价于“提升性能”,因为 GPU 更依赖规则的数据布局和大规模并行计算来发挥性能。

此外,在 Cholesky 分解中,L21L_{21} 的列数仅为分块大小 nbnb,通常远小于矩阵的高度(例如 nb=128nb=128,而 MM 可达数千)。因此在计算 BiBiTB_i B_i^T 时,其本质为:

(M×nb)(nb×M)(M×nb)⋅(nb×M)

虽然结果矩阵规模较大,但其计算受限于较小的 nbnb 维度,导致计算呈现“窄 K 维”的特征。这种形状难以充分利用 Tensor Core 对大规模方阵乘法的优化能力,从而进一步限制了 SYRK 的性能表现。

从迭代到递归

递归 TRSM:消除列依赖

将待解的瘦长矩阵 X 和已知的三角阵 U 进行水平二分:

[X1X2][U11U120U22]=[A1A2]\begin{bmatrix} X_1 & X_2 \end{bmatrix} \begin{bmatrix} U_{11} & U_{12} \\ 0 & U_{22} \end{bmatrix} = \begin{bmatrix} A_1 & A_2 \end{bmatrix}

其中

  • X1,X2X_1, X_2:形状均为 M×(N/2)M \times (N/2);作为一个整体,成为待求解的 L21L_{21}.
  • U11,U22U_{11}, U_{22}:形状均为 (N/2)×(N/2)(N/2) \times (N/2) 的上三角矩阵。
  • U12U_{12}:形状为 (N/2)×(N/2)(N/2) \times (N/2) 的一般方阵。

算法步骤:

  1. X1X_1
  2. 更新 A2A_2
  3. X2X_2

求解出 X1,X2X_1 ,X_2,我们就得到了 L21L_{21}.

X1X_1

X1U11=A1X_1 U_{11} = A_1

这是一个规模缩小了一般的子 TRSM 问题。我们不需要知道右边发生了什么,先通过递归求解出 X1X_1.

更新 A2A_2

由矩阵乘法可知:X1U12+X2U22=A2X_1 U_{12} + X_2 U_{22} = A_2.
变换可得:

A^2=A2X1U12\hat{A}_2 = A_2 - X_1 U_{12}

这是一个非常关键的更新:在迭代算法中,这一步对应多次细碎的列向量更新。

在这里,由于我们一次性算出来一整个 X1X_1 块,这一步就变成了一个标准的 GEMM.

X2X_2

X2U22=A^2X_2 U_{22} = \hat{A}_2

此时 A2A_2 已经减去了所有来自 X1X_1 的影响。我们只需要递归求解这个子 TRSM 即可得到 X2X_2.

递归 SYRK:引入更多的 GEMM

在之前的 SYRK 实现中,问题的核心在于:为了节省一半的理论计算量(对称性),我们牺牲了硬件最擅长的规则并行性。

我们将 L21L_{21} 水平切分为上下两块 B1B_1B2B_2A22A_{22} 则根据对称性划分为三个部分(只关注下三角):

L21=[B1B2],A22=[C11C21C22]L_{21} = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \quad A_{22} = \begin{bmatrix} C_{11} & \\ C_{21} & C_{22} \end{bmatrix}

更新公式 A22=A22L21L21TA_{22} = A_{22} - L_{21}L_{21}^T 展开后变为:

  • C11=C11B1B1TC_{11} = C_{11} - B_1 B_1^T (规模减半的子 SYRK 问题)
  • C21=C21B2B1TC_{21} = C_{21} - B_2 B_1^T (标准 GEMM 问题)
  • C22=C22B2B2TC_{22} = C_{22} - B_2 B_2^T (规模减半的子 SYRK 问题)

根据递归逻辑,执行过程如下:

  1. 左上子块递归:对 C11C_{11} 调用递归 SYRK。当块大小降低到硬件阈值(如 nb=16nb=163232)时,执行底层优化的小规模 SYRK Kernel。
  2. 中间块集火更新(核心步骤):计算 C21=C21B2B1TC_{21} = C_{21} - B_2 B_1^T。注意:这里的 C21C_{21} 是一个完整的矩形块,不涉及任何对角线。因此,我们可以直接调用 GPU 上最强力的 GEMM 接口。随着矩阵总规模 NN 的增大,这一步转化的 GEMM 占据了绝大部分的计算量。
  3. 右下子块递归:对 C22C_{22} 调用递归 SYRK,逻辑同步骤 1。

在递归 SYRK 中,计算被重组为:

  • 两个规模减半的子问题(递归 SYRK)。
  • 一个大规模 GEMM(C21C_{21} 更新)。

随着问题规模的增大,这三部分的计算占比发生显著变化:

  • 递归部分:规模不断减小,占比逐渐降低。
  • GEMM 部分:始终作用于较大的矩形块,占据主导地位。

整体计算逐渐转化为:以 GEMM 为核心的大规模计算 + 少量对称处理。

递归 Cholesky 分解算法

我们将待分解的对称正定矩阵 AA 沿对角线平分为四个大块(每个大块占据原矩阵 1/41/4 的面积):

A=[A11A21TA21A22]A = \begin{bmatrix} A_{11} & A_{21}^T \\ A_{21} & A_{22} \end{bmatrix}

与迭代法每次只步进 nbnb 不同,这里的 A11A_{11} 大小通常是当前矩阵规模的 1/2。整个分解过程遵循以下四个递归步骤:

  1. 左上角递归分解 (Recursive POTRF):对 A11A_{11} 再次调用 recursive_potrf。它会继续二分,直到矩阵规模缩小到硬件甜点区,然后调用最底层的非分块算法直接算出 L11L_{11}
  2. 求解左下角长条 (Recursive TRSM):利用算出的 L11L_{11},通过我们之前介绍的递归 TRSM 算法求解 L21L_{21}(即 L21=A21L11TL_{21} = A_{21} L_{11}^{-T})。这一步会将大量的三角求解转化为满血的 GEMM。
  3. 更新右下角剩余块 (Recursive SYRK):利用 L21L_{21},通过我们之前介绍的递归 SYRK 算法更新 A22A_{22}(即 A22=A22L21L21TA_{22} = A_{22} - L_{21} L_{21}^T)。这一步是计算量最大的,也是 GEMM 占比最高的部分。
  4. 右下角递归分解 (Recursive POTRF):对更新后的 A22A_{22} 再次调用 recursive_potrf,重复上述过程。

递归算法将 A11A_{11} 设为矩阵的一半(N/2N/2)。这意味着在递归回溯的过程中,产生的 TRSM 和 SYRK 规模是与原矩阵规模成比例的。这种“方阵化”的转变,让任务形状从“瘦长”变成了“大方块”,从而能无缝对接 Tensor Cores 最擅长的 16×1616 \times 16 或更高维度的矩阵乘法。

总结

  • 递归化从根本上改变了计算形态,使原本“瘦长”的 L21L_{21} 以及窄 K 维的 SYRK 结构,逐步演化为更适合 GPU 的大规模方块计算,从而提升硬件利用率。
  • 递归 TRSM 通过分治策略打破了列间依赖,将原本强串行的三角求解过程重构为“局部递归 + 全局 GEMM 更新”的混合计算模式,其中关键依赖被转化为可并行执行的矩阵乘法。
  • 递归 SYRK 则利用结构拆分,将对称约束下的计算重新组织为“两个递归子问题 + 一个主导 GEMM 更新”,随着规模增大,计算逐渐被 GEMM 主导,对称性约束的影响被弱化。

在整体递归 Cholesky 框架中,核心计算逐步统一到高吞吐的 GEMM 形式,使算法从“算子组合驱动”转向“以 GEMM 为中心的计算重构”,从而更充分发挥 Tensor Core 的硬件潜力

留言