Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers

news/2024/10/5 5:16:59

目录
  • 符号说明
  • LSSL
    • 和其它方法的联系
  • 代码

Gu A., Johnson I., Goel K., Saab K., Dao T., Rudra A., and Re C. Combining recurrent, convolutional, and continuous-time models with linear state-space layers. NeurIPS, 2021.

State space representaion-wiki.

Mamba 系列的第二作: LSSL.

符号说明

  • \(u(t) \in \mathbb{R}\), 输入信号;
  • \(x(t) \in \mathbb{R}^N\), 中间状态;
  • \(y(t) \in \mathbb{R}\), 输出信号

LSSL

  • 从 LSSL 开始, 作者开始围绕 linear system 做文章:

    \[\tag{1} \dot{x}(t) = A x(t) + B u(t), \\ y(t) = C x(t) + Du(t), \]

    注意, 这里作者把 \(A, B, C, D\) 简化为和时间 \(t\) 无关的量, 且仅仅讨论的是一维的信号.

  • 采用 generalized bilinear transform (GBT) 可以将上述的 ODE 离散化:

    \[x_t = \bar{A} x_{t-1} + \bar{B} u_t, \\ y_t = C x_t + D u_t, \]

    其中

    \[\bar{A} = (I - \alpha \Delta t \cdot A)^{-1} (I + (1 - \alpha) \Delta t \cdot A), \\ \bar{B} = \Delta t (I - \alpha \Delta t \cdot A)^{-1} B, \]

    \(\Delta t\) 是时间间隔, 而 \(\alpha\) 是一个 bilinear 的超参数. 具体的推导可以见 here.

  • 我们现在仅关注 \(x_t\), 然后看看一个具体的例子. 取 \(A=-1, B=1, \alpha=1, \Delta t = \exp(z)\), 我们有

    \[x_t = \frac{1}{1 + \exp(z)} x_{t-1} + \frac{\exp(z)}{1 + \exp(z)} u_t = (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t. \]

    这实际上就是一个 gating 机制 (常常用在 RNN 的更新上).

  • ok, 对于 RNN, 我们可以把其中的一层看出是对 (1) 的一次近似, 那么多层的效果是什么? 作者认为这和 Picard iteration 有关系, Picard iteration, 即

    \[x_{i+1}(t) := x_i (t_0) + \int_{t_0}^t f(s, x_i(s)) ds \]

    可以证明随着 \(i\) 的增加, 会逐步收敛到真实解, 换言之, 多层的叠加可以让误差越来越小.

  • Deep LSSLs:

    • LSSLs 的具体构造就是上述离散过程的叠加, 同时不同 block 之间添加 skip connection 和 layer norm.
    • 假设我们的输入信号是 \(\mathbb{R}^{L \times H}\) 的, 其中 \(L\) 表示序列长度, \(H\) 是维度, 此时信号不是 1 维的. 作者的做法是, 为每个维度单独设立:

    \[A \in \mathbb{R}^{N \times N}, \quad B \in \mathbb{R}^{N \times 1}, \quad C \in \mathbb{R}^{1 \times N}, \quad D \in \mathbb{R}, \quad \Delta t \in \mathbb{R} \]

    分别进行上述的离散过程. 此外, 这里需要注意的是, \(\Delta t\) 我们也是可学习的.

    • 作者还提到, 输出信号 \(y(t)\) 不一定必须是 1 维的, 也可以是 \(M\) 维的, 此时 \(C \in \mathbb{R}^{M \times N}, D \in \mathbb{R}^{M \times 1}\). 这会导致最后的输出维度是 \(H\cdot M\), 可以通过 MLP 映射回 \(H\).
      所以总共的参数量为:

    \[HNN + HN1 + HMN + HM + H + HMH = \mathcal{O}(HN^2 + HMN + H^2M). \]

注: 作者好像把 LSSL 的代码删掉了, 不过我注意到后续的 S4 的设定里面, 是为每个维度单独设立 \(A, B, C, D\) 还是共享是可以选择的 (也可以是部分维度共享, 取决于 n_ssm 这个参数).

和其它方法的联系

  • LSSL 除了可以看成是 RNN 外, 实际上还具有卷积的特性, 容易发现:

    \[\begin{array}{ll} y_k &= C(\bar{A})^k \bar{B} u_0 + C (\bar{A})^{k-1} \bar{B} u_1 + \cdots + C \overline{AB} u_{k-1} + \bar{B} u_k + D u_k \\ &= \sum_{s} C(\bar{A})^{k-s} u_s \\ &= \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du \end{array}, \]

    其中

    \[\mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \]

  • 所以, LSSL 具有卷积的优点, 可以并行计算.

代码

[official-code]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hjln.cn/news/44213.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

堆基础知识

arenachunk通俗地说,一块由分配器分配的内存块叫做一个 chunk,包含了元数据和用户数据。具体一点,chunk 完整定义如下: struct malloc_chunk {INTERNAL_SIZE_T mchunk_prev_size; /* Size of previous chunk (if free). */INTERNAL_SIZE_T mchunk_size; …

【Azure Spring Apps】Spring App部署上云遇见 502 Bad Gateway nginx

问题描述 在部署Azure Spring App应用后,访问应用,遇见了502 Bad Gateway Nginx。问题解答 502 Bad Gateway, 并且由Nginx返回。而自己的应用中,并没有定义Nginx相关内容,所以需要查看问题是否出现在Azure Spring App服务的设置上。 根据Spring App的通信模型图判断,502的…

学生管理系统的CRUD

include using namespace std; typedef struct Studnet { //初始化结构体变量 int ID; double math_scores; double english_scores; double computer_scores; double total_scores;}Student; void Input_student_score(int size, Student* stu); //输入所有学生信息 void Out…

C语言中关于Base64编码的基础原理

Base64编码简述: 1.Base64是网络上最常见的用于传输8Bit字节码的编码方式之一,Base64就是一种基于64个可打印字符来表示二进制数据的方法。 2.Base64,就是包括小写字母a-z、大写字母A-Z、数字0-9、符号"+"、"/"一共64个字符的字符集,(任何符号都可以转…

09-盒子模型

盒子模型01 认识盒子模型02 盒子模型的四边03 盒子边框04 盒子内边距-padding 通常用于设置边框和内容之间的间距 <!DOCTYPE html> <html lang="en"> <head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible&quo…

试了下ocr

pdf能看了,拓展的驱动下,想着是否可以ORC呢,识别到文字内容更有帮助。 按网搜的顺序,开始是用pytesseract,pip安装顺利,但运行不了,提示找不到pytesseract,按网上的帮助下载win安装包,选上中文包,再试,可以运行了,就是中文基本识别不了,也不知哪里改善,只得作罢。…

fastjson1

@目录前言分析复制文件清空文件出现问题和分析问题解决分析问题再次出现问题再次分析最终结果读取文件分析poc拓宽场景极限环境poc优化修改再次优化poc的分析写入文件SafeFileOutputStream写文件java8无依赖读文件在commons-io库下的写入文件原因利用链分析组合poc出现问题和分…

解决运行loadRunner报错无法进行代理的错误

选择第二个,不设置代理,可以实现回放不会报错,但是今日运行遇到错误,无法实现全部的录制脚本回访完毕,卡住打开网址处的脚本。直接运行完毕,不会报错。