卷积核格式推导
递推格式展开:考虑如下离散后的 S4D 系统,其中 ,
其中 ,,,,,。我们有
卷积核定义:卷积核可以定义为如下
由于 是对角矩阵, 是向量, 是向量,因此 是一个数,可以写为
从而
其中 满足如上式子。
Pytorch 实现:先根据 A, B, C, L 计算出 Kernel K (L,)。然后再转换到频域计算。
import torch
import torch.fft
def compute_ssm_kernel(A, B, C, L):
"""
步骤 1: 生成卷积核 K
对应公式: K_t = sum_{n=1}^N C_n * (A_n)^t * B_n
参数:
A: (N,) 复数对角矩阵的对角元素
B: (N,) 复数输入投影
C: (N,) 复数输出投影
L: 序列长度
返回:
K: (L,) 复数卷积核
"""
# 1. 构造时间步向量 t = [0, 1, ..., L-1]
# shape: (L)
t = torch.arange(L, device=A.device)
# 2. 计算 A 的幂次 (A_n)^t
# 利用广播机制: (N, 1) ** (L) -> (N, L)
# 这一步计算了所有状态 n 在所有时刻 t 的衰减项
A_powers = torch.pow(A.unsqueeze(-1), t)
# 3. 计算各项乘积 C_n * (A_n)^t * B_n
# term shape: (N, L)
term = (C * B).unsqueeze(-1) * A_powers
# 4. 对状态维度 N 求和 (sum_{n=1}^N)
# 这一步将 N 个独立状态的响应混合成一个系统的脉冲响应
K = torch.sum(term, dim=0)
return K # shape: (L,)
def fft_convolution(x, K):
"""
步骤 2: FFT 卷积加速
对应公式: y = x * K
参数:
x: (Batch, L) 实数输入序列
K: (L,) 复数卷积核
返回:
y: (Batch, L) 实数输出序列
"""
L = x.shape[-1]
# 1. 确定 FFT 长度 (通常设为 2*L 以避免循环卷积混叠)
fft_len = 2 * L
# 2. 输入 x (实数) -> 频域
# 使用 rfft 因为 x 是实数,只计算一半频谱
x_f = torch.fft.rfft(x, n=fft_len)
# 3. 卷积核 K (复数) -> 频域
# 使用 fft 因为 K 本身是复数
k_f = torch.fft.fft(K, n=fft_len)
# 4. 对齐频谱
# rfft 得到的频谱长度为 fft_len//2 + 1
# 我们需要截取 k_f 对应的前半部分正频率
k_f = k_f[..., :x_f.shape[-1]]
# 5. 频域乘法 (对应时域卷积)
y_f = x_f * k_f
# 6. 逆变换回时域 + 截断
y = torch.fft.irfft(y_f, n=fft_len)
y = y[..., :L] # 去掉补零产生的部分
return y