Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

卷积核格式推导

递推格式展开:考虑如下离散后的 S4D 系统,其中

其中 。我们有

卷积核定义:卷积核可以定义为如下

由于 是对角矩阵, 是向量, 是向量,因此 是一个数,可以写为

从而

其中 满足如上式子。

Pytorch 实现:先根据 A, B, C, L 计算出 Kernel K (L,)。然后再转换到频域计算。

import torch
import torch.fft

def compute_ssm_kernel(A, B, C, L):
    """
    步骤 1: 生成卷积核 K
    对应公式: K_t = sum_{n=1}^N C_n * (A_n)^t * B_n
    
    参数:
        A: (N,) 复数对角矩阵的对角元素
        B: (N,) 复数输入投影
        C: (N,) 复数输出投影
        L: 序列长度
    返回:
        K: (L,) 复数卷积核
    """
    # 1. 构造时间步向量 t = [0, 1, ..., L-1]
    # shape: (L)
    t = torch.arange(L, device=A.device)
    
    # 2. 计算 A 的幂次 (A_n)^t
    # 利用广播机制: (N, 1) ** (L) -> (N, L)
    # 这一步计算了所有状态 n 在所有时刻 t 的衰减项
    A_powers = torch.pow(A.unsqueeze(-1), t)
    
    # 3. 计算各项乘积 C_n * (A_n)^t * B_n
    # term shape: (N, L)
    term = (C * B).unsqueeze(-1) * A_powers
    
    # 4. 对状态维度 N 求和 (sum_{n=1}^N)
    # 这一步将 N 个独立状态的响应混合成一个系统的脉冲响应
    K = torch.sum(term, dim=0) 
    
    return K # shape: (L,)

def fft_convolution(x, K):
    """
    步骤 2: FFT 卷积加速
    对应公式: y = x * K
    
    参数:
        x: (Batch, L) 实数输入序列
        K: (L,) 复数卷积核
    返回:
        y: (Batch, L) 实数输出序列
    """
    L = x.shape[-1]
    
    # 1. 确定 FFT 长度 (通常设为 2*L 以避免循环卷积混叠)
    fft_len = 2 * L
    
    # 2. 输入 x (实数) -> 频域
    # 使用 rfft 因为 x 是实数,只计算一半频谱
    x_f = torch.fft.rfft(x, n=fft_len)
    
    # 3. 卷积核 K (复数) -> 频域
    # 使用 fft 因为 K 本身是复数
    k_f = torch.fft.fft(K, n=fft_len)
    
    # 4. 对齐频谱
    # rfft 得到的频谱长度为 fft_len//2 + 1
    # 我们需要截取 k_f 对应的前半部分正频率
    k_f = k_f[..., :x_f.shape[-1]]
    
    # 5. 频域乘法 (对应时域卷积)
    y_f = x_f * k_f
    
    # 6. 逆变换回时域 + 截断
    y = torch.fft.irfft(y_f, n=fft_len)
    y = y[..., :L] # 去掉补零产生的部分
    
    return y