2149 words
11 minutes
关于SSM的数学推导

连续和离散形式的状态方程#

1. 初始形式:连续时间状态方程#

连续时间下的 SSM 状态方程为:

h˙(t)=Ah(t)+Bx(t)\dot{h}(t) = A h(t) + B x(t)

为了使用代数方法推导,我们先对其进行 拉普拉斯变换(Laplace Transform),从时间域 tt 转换到复频域 ss。假设初始状态为 0,则有:

sH(s)=AH(s)+BX(s)s H(s) = A H(s) + B X(s)

移项整理得到:

(sIA)H(s)=BX(s)(sI - A) H(s) = B X(s)

2. 双线性变换公式(sszz 的代换)#

双线性变换的核心思想是用 zz 域的表达式来近似 ss 域的微分算子。公式如下:

s2Δ1z11+z1s \approx \frac{2}{\Delta} \frac{1 - z^{-1}}{1 + z^{-1}}

其中:

  • Δ\Delta (或写作 Δt\Delta t) 是采样步长(Time step)。
  • z1z^{-1} 是单位延迟算子,对应时域中的 ht1h_{t-1}

3. 推导过程#

ss 的表达式代入连续系统的拉普拉斯方程中:

(2Δ1z11+z1IA)H(z)=BX(z)\left( \frac{2}{\Delta} \frac{1 - z^{-1}}{1 + z^{-1}} I - A \right) H(z) = B X(z)

我们的目标是解出 H(z)H(z)X(z)X(z) 的关系,并将其转化为时域递归形式 ht=Aˉht1+Bˉxth_t = \bar{A}h_{t-1} + \bar{B}x_t

第一步:消除分母#

为了方便计算,方程两边同时左乘 (1+z1)I(1 + z^{-1})I(注意矩阵乘法顺序,虽此处 II 可交换):

[2Δ(1z1)IA(1+z1)]H(z)=B(1+z1)X(z)\left[ \frac{2}{\Delta} (1 - z^{-1}) I - A (1 + z^{-1}) \right] H(z) = B (1 + z^{-1}) X(z)

第二步:按 zz 的幂次分组#

将含有 z1z^{-1}(过去时刻)和常数项(当前时刻)分开:

[(2ΔIA)(2ΔI+A)z1]H(z)=B(1+z1)X(z)\left[ (\frac{2}{\Delta} I - A) - (\frac{2}{\Delta} I + A) z^{-1} \right] H(z) = B (1 + z^{-1}) X(z)

展开方程左边:

(2ΔIA)H(z)(2ΔI+A)z1H(z)=BX(z)+Bz1X(z)(\frac{2}{\Delta} I - A) H(z) - (\frac{2}{\Delta} I + A) z^{-1} H(z) = B X(z) + B z^{-1} X(z)

第三步:转回离散时间域#

利用 ZZ 变换的性质:H(z)htH(z) \to h_tz1H(z)ht1z^{-1}H(z) \to h_{t-1}。 将上式写回时域差分方程:

(2ΔIA)ht(2ΔI+A)ht1=Bxt+Bxt1(\frac{2}{\Delta} I - A) h_t - (\frac{2}{\Delta} I + A) h_{t-1} = B x_t + B x_{t-1}

第四步:求解 hth_t#

hth_t 保留在左边,其余项移到右边:

(2ΔIA)ht=(2ΔI+A)ht1+B(xt+xt1)(\frac{2}{\Delta} I - A) h_t = (\frac{2}{\Delta} I + A) h_{t-1} + B (x_t + x_{t-1})

为了简化系数,我们先对等式两边同时乘以 Δ2\frac{\Delta}{2}

(IΔ2A)ht=(I+Δ2A)ht1+Δ2B(xt+xt1)(I - \frac{\Delta}{2} A) h_t = (I + \frac{\Delta}{2} A) h_{t-1} + \frac{\Delta}{2} B (x_t + x_{t-1})

最后,左乘 (IΔ2A)1(I - \frac{\Delta}{2} A)^{-1} 以孤立 hth_t

ht=(IΔ2A)1(I+Δ2A)Aˉht1+(IΔ2A)1Δ2BInput coeff(xt+xt1)h_t = \underbrace{(I - \frac{\Delta}{2} A)^{-1} (I + \frac{\Delta}{2} A)}_{\bar{A}} h_{t-1} + \underbrace{(I - \frac{\Delta}{2} A)^{-1} \frac{\Delta}{2} B}_{\text{Input coeff}} (x_t + x_{t-1})

4. 结果整理与 SSM 中的惯用形式#

经过上述严谨推导,我们得到了双线性变换后的精确递归公式:

  1. 离散化状态矩阵 Aˉ\bar{A}

    Aˉ=(IΔ2A)1(I+Δ2A)\bar{A} = (I - \frac{\Delta}{2} A)^{-1} (I + \frac{\Delta}{2} A)
  2. 关于输入的离散化: 在严格的双线性变换推导中,输入项变成了 Δ2Bˉraw(xt+xt1)\frac{\Delta}{2} \bar{B}_{raw} (x_t + x_{t-1}),这意味着当前状态取决于当前输入和上一步输入。

    然而,在 S4、Mamba 等深度学习 SSM 的文献和实现中,为了保持形式简洁(即标准的 RNN 形式 ht=Aˉht1+Bˉxth_t = \bar{A}h_{t-1} + \bar{B}x_t),通常会做以下处理之一:

    • 近似处理:忽略 xt1x_{t-1} 的影响(或认为在采样间隔内输入变化不大)。
    • 定义 Bˉ\bar{B}:直接定义离散后的输入矩阵 Bˉ\bar{B} 吸收系数。

    最常见的 SSM 离散化结论(如 S4 论文中的公式)是:

    Bˉ=(IΔ2A)1ΔB\bar{B} = (I - \frac{\Delta}{2} A)^{-1} \Delta B

    (注意这里是 ΔB\Delta B 而不是 Δ2B\frac{\Delta}{2} B,这通常来自于假设输入在区间内为常数或者是对输入项系数的一种归一化定义)

总结#

通过 s2Δ1z11+z1s \to \frac{2}{\Delta} \frac{1 - z^{-1}}{1 + z^{-1}} 的代换,我们推导出了 SSM 的离散化参数:

Aˉ=(IΔ/2A)1(I+Δ/2A)Bˉ(IΔ/2A)1ΔB\begin{aligned} \bar{A} &= (I - \Delta/2 \cdot A)^{-1}(I + \Delta/2 \cdot A) \\ \bar{B} &\approx (I - \Delta/2 \cdot A)^{-1} \Delta B \end{aligned}

最终离散状态方程为:

ht=Aˉht1+Bˉxth_t = \bar{A} h_{t-1} + \bar{B} x_t

SSM卷积核(频谱)的计算#

在卷积表示中,需要进行一个很大的因果卷积。为了快速进行这个卷积,需要使用方法快速算出SSM卷积核的频谱。SSM卷积核的频谱,即为截断生成函数在一系列单位根上的取值。

在 S4 (Structured State Space sequence Models) 论文中,计算卷积核的截断生成函数(即系统的频率响应)时,核心难点在于计算 Resolvent(预解集),即形如 (sIA)1(sI - A)^{-1}(IAˉz)1(I - \bar{A}z)^{-1} 的项。

直接计算这个逆矩阵的复杂度是 O(N3)O(N^3),这在状态维度 NN 较大时是不可接受的。S4 通过 DPLR (Diagonal Plus Low Rank) 结构分解和 Woodbury 矩阵恒等式 巧妙地规避了直接求逆。

下面是详细的数学推导过程。


1. 问题的定义#

SSM 的卷积核 Kˉ\bar{K} 的生成函数(Z 变换)即为该线性系统的传递函数:

H(z)=Cˉ(IAˉz)1Bˉ\mathcal{H}(z) = \bar{C} (I - \bar{A} z)^{-1} \bar{B}

我们需要在 zz 取单位圆上的 LL 个点(即 zk=exp(i2πkL)z_k = \exp(-i \frac{2\pi k}{L}),对应 FFT 的频率点)时计算该式的值。

如果 Aˉ\bar{A} 是一个一般的 N×NN \times N 矩阵,对每个频率点求逆的代价极大。

2. 关键结构:DPLR (Diagonal Plus Low Rank)#

S4 论文证明了,虽然最优的状态矩阵 AA(即 HiPPO 矩阵)不是对角阵,但它们都可以被分解为 “对角矩阵 + 低秩矩阵” 的形式(或者可以通过相似变换转化为这种形式)。

假设连续时间状态矩阵 AA 具有如下形式(秩为 1 的 DPLR):

A=ΛPQA = \Lambda - P Q^\top

其中:

  • Λ\Lambda 是对角矩阵(Diagonal),Λ=diag(λ1,,λN)\Lambda = \text{diag}(\lambda_1, \dots, \lambda_N)
  • P,QRN×1P, Q \in \mathbb{R}^{N \times 1} 是列向量(Low Rank, rank=1)。

3. 连接离散与连续:从 Aˉ\bar{A}AA#

我们在前一个问题中推导过双线性变换:

Aˉ=(IΔ2A)1(I+Δ2A)\bar{A} = (I - \frac{\Delta}{2} A)^{-1} (I + \frac{\Delta}{2} A)

直接求 (IAˉz)1(I - \bar{A}z)^{-1} 很麻烦。S4 的技巧在于利用双线性变换的关系,将离散域的求逆问题转化为连续域的求逆问题。

s=2Δ1z1+zs = \frac{2}{\Delta} \frac{1-z}{1+z},我们可以证明(推导略繁,但这只是代数代换),离散传递函数与连续传递函数有如下关系:

Cˉ(IAˉz)1BˉconstC(sIA)1B\bar{C} (I - \bar{A} z)^{-1} \bar{B} \approx \text{const} \cdot C (s I - A)^{-1} B

(注:严格来说系数会有变化,且 Cˉ,Bˉ\bar{C}, \bar{B}C,BC, B 有变换关系,但核心计算瓶颈完全取决于如何计算 (sIA)1(sI - A)^{-1})

因此,问题被归约为:如何快速计算 (sIA)1(sI - A)^{-1},其中 ss 是标量,A=ΛPQA = \Lambda - P Q^\top

4. 核心推导:Woodbury 矩阵恒等式#

我们需要计算:

Y=(sI(ΛPQ))1Y = (s I - (\Lambda - P Q^\top))^{-1}

即:

Y=(sIΛD+PQ)1Y = (\underbrace{s I - \Lambda}_{D} + P Q^\top)^{-1}

这里令 D=sIΛD = s I - \Lambda。因为 Λ\Lambda 是对角阵,DD 也是对角阵,且其逆矩阵 D1D^{-1} 极其容易计算(只需对角线元素取倒数),复杂度为 O(N)O(N)

根据 Woodbury 矩阵恒等式(Matrix Inversion Lemma):

(D+UV)1=D1D1U(I+VD1U)1VD1(D + UV^\top)^{-1} = D^{-1} - D^{-1} U (I + V^\top D^{-1} U)^{-1} V^\top D^{-1}

D=sIΛ,U=P,V=QD=sI-\Lambda, U=P, V=Q 代入上式:

(sIA)1=D1D1P(1+QD1P)1scalar inverseQD1(sI - A)^{-1} = D^{-1} - D^{-1} P \underbrace{(1 + Q^\top D^{-1} P)^{-1}}_{\text{scalar inverse}} Q^\top D^{-1}

复杂度分析(为何变快了?):#

  1. D1D^{-1}DD 是对角阵,求逆只需 O(N)O(N)
  2. QD1PQ^\top D^{-1} P
    • D1PD^{-1} P 是对角阵乘向量,复杂度 O(N)O(N)
    • 再乘 QQ^\top(向量点积),复杂度 O(N)O(N)
    • 结果是一个标量(或者在秩为 rr 时是 r×rr \times r 小矩阵)。
  3. 求逆:对标量(或小矩阵)求逆,复杂度 O(1)O(1)
  4. 其余乘法:都是向量与对角阵的运算,均为 O(N)O(N)

结论:通过 Woodbury 恒等式,我们将 O(N3)O(N^3) 的矩阵求逆降低到了 O(N)O(N)

5. 进一步:柯西核 (Cauchy Kernel) 与 FFT#

虽然我们解决了单点求逆的问题,但我们需要计算整个序列(比如长度 LL)的卷积核。这意味着我们需要对 LL 个不同的 zkz_k(对应 LL 个不同的 sks_k)进行上述计算。

将上述 Woodbury 展开式代入传递函数 H(s)=C(sIA)1BH(s) = C (sI - A)^{-1} B,最终形式会变成形如下式的求和:

kj=n=1Ncnbnsjλn×CorrectionTermk_j = \sum_{n=1}^N \frac{c_n b_n}{s_j - \lambda_n} \times \text{CorrectionTerm}

这里的结构看起来像:

xiyjsjλi\sum \frac{x_i y_j}{s_j - \lambda_i}

这被称为 Cauchy Matrix(柯西矩阵) 乘法。

S4 的另一个主要贡献是指出这种形式的计算可以通过 FFT(快速傅里叶变换)O(N+LlogL)O(N + L \log L) 的时间内完成,而不是朴素的 O(NL)O(N \cdot L)

总结#

S4 避免直接计算逆矩阵的数学推导逻辑链如下:

  1. 分解:利用 HiPPO 矩阵性质,将 AA 分解为 对角阵 + 低秩阵 (A=ΛPQA = \Lambda - PQ^\top)。
  2. 转化:利用双线性变换,将离散系统 (IAˉz)1(I - \bar{A}z)^{-1} 的计算转化为连续系统预解集 (sIA)1(sI - A)^{-1} 的计算。
  3. 降维:使用 Woodbury 矩阵恒等式,利用 Λ\Lambda 易于求逆的特点,将 N×NN \times N 矩阵的求逆运算转化为对角矩阵运算和标量(或低秩矩阵)求逆,将单点计算复杂度从 O(N3)O(N^3) 降至 O(N)O(N)
  4. 加速:结合 FFT 算法处理所有频率点,实现全局高效计算。