๋์จ์ง ๋ฒ์จ 1๋ ๋ ๋์์ง๋ง ์ต์ ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ๋ฅผ ์ ํ์ง๊ฐ ๋ฐฑ๋ง๋ ์ ๋ ๋ ๊ฒ ๊ฐ์์ ํ ๋ฒ ์ฝ์ด๋ณด๋ mamba... ๊ทธ๋ฆฌ ์ ํํ ๋ฆฌ๋ทฐ๋ ์๋ ์ ์์ต๋๋ค. ์ฌ์ค ๋ฒ์ญ์ ๊ฐ๊น๊ณ ์ข ๋ ์ดํดํด๋ณด๋ฉด์ ๋ด์ฉ ์์ ํ๊ฒ ์ต๋๋ค.
1. Introduction
์ต๊ทผ๋ค์ด Structured State Space Sequence Models(SSMs) ๊ฐ ์ํ์ค ๋ชจ๋ธ๋ง ๊ตฌ์กฐ ๋ถ์ผ์์ ์ ๋งํ ํด๋์ค๋ก ๋ฑ์ฅํ๋ค. ์ ํต์ ์ธ state space models์ ์๊ฐ์ ๋ฐ์ CNN๊ณผ RNN์ ํตํฉ์ ์กฐ์จํ๋ค(interpreted).
Mamba์์๋ selective state space model์ ์๋ก์ด ์ข
๋ฅ๋ฅผ ์ ์ํ๋ค. ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก ํ์ฅํ๋ฉด์ transformer์ ๋ชจ๋ธ๋ง ํ์๋ฅผ ๋ฐ๋ผ์ก๊ธฐ ์ํด์ ๋ช๋ช์ axes(์ฌ๊ธฐ์์๋ ๊ธฐ์ค? ์ถ.)์ ๋ํ ๊ธฐ์กด ์ฐ๊ตฌ๋ฅผ ๋ฐ์ ์ํฌ ๊ฒ์ด๋ค.
Selection Mechanism
- ๊ธฐ์กด ๋ชจ๋ธ์ ์ฃผ์ ํ๊ณ๋ ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅ ์์กด์ ๋ฐฉ์์ผ๋ก, ํจ์จ์ ์ผ๋ก select ํ๋ ๋ฅ๋ ฅ์ ๊ฐ์ง๋ค๋ ์ ์ด๋ค.
- Mamba์์๋ ์ ํ์ ๋ณต์ฌ์ ํค๋ ์ ๋์ ๊ฐ์ ์ค์ํ ํฉ์ฑ ์์ ์ ๋ฐํ์ผ๋ก ํ ์ง๊ด์ ๊ธฐ๋ฐ์ผ๋ก ์ ๋ ฅ์ ๊ธฐ๋ฐํ SSM ํ๋ผ๋ฏธํฐ๋ฅผ ๋งค๊ฐ๋ณ์ํํ์ฌ ๊ฐ๋จํ ์ ํ ๋งค์ปค๋์ฆ์ ์ค๊ณํ์๋ค.
- ์ด๋ ๋ชจ๋ธ์ด ๊ด๋ จ์ด ์ ์ ์ ๋ณด๋ฅผ ํํฐ๋งํ๊ณ ๊ด๋ จ์๋ ์ ๋ณด๋ฅผ ๊ธฐ์ตํ๋๋ก ํ๋ค.
Hardware-aware Algorithm
Architecture
2023๋
๋ฐํ๋ SSM Acritecture๊ณผ Transformer์ MLP block์ single block์ผ๋ก ํ์ฌ ์ ํ์ state space ๋ฅผ ํตํฉํ ๋จ์ํ๊ณ ํตํฉ์ ์ธ ์ํคํ
์ณ ๋์์ธ(Mamba)๋ฅผ ์ค๊ณํ์๋ค.
Selective SSMs๋ fully recurrent models๋ก ์ํ์ค์์ ๋์ํ๋ ์ผ๋ฐ์ ์ธ foundation models์ ๋ฐฑ๋ณธ์ด ๋๊ธฐ ์ ํฉํ key properties๊ฐ ์๋ค.
- high quality
- fast training and inference
- long context
์ด ๋ ผ๋ฌธ์์๋ ๋ช๊ฐ์ง ํ์ ์ modalities์ ์ค์ ์ ๋ฐ๋ฅธ pretraining quality์ domain-specificํ ์ฑ๋ฅ ํ๊ฐ๋ก Mamba์ ์ผ๋ฐ์ ์ธ ์ํ์ค FM backbone์ผ๋ก์์ ์ ์ฌ๋ ฅ์ ํ๊ฐํ๋ค.
2. SSM(State Space Model)
Structured state space sequence models(S4)๋ RNN, CNN, ๊ทธ๋ฆฌ๊ณ ์ ํต์ ์ธ state space model๊ณผ ์ฐ๊ด์ด ์๋ ๋ฅ๋ฌ๋์ ์ํ ์ต์ ์ํ์ค ๋ชจ๋ธ์ ํ ์ข ๋ฅ์ด๋ค. ์ด๋1์ฐจ์ ํจ์ ํน์ ์ํ์ค๋ฅผ mappingํ๋, $x(t) \in R \to y(t) \in R$ implicit latent state $h(t) \in R^{N}$ ๋ถ๋ถ์ ์ฐ์ ์์คํ ์ด๋ค.
S4 ๋ชจ๋ธ์ sequence-to-sequence transformation์ ๋ ๋จ๊ณ๋ก ์ ์ํ๋ 4๊ฐ์ง ํ๋ผ๋ฏธํฐ($\Delta, A, B, C$) ๋ก ์ ์๋๋ค.
์ด์ฐํ Discretization
์ฐ์ ํ๋ผ๋ฏธํฐ($\Delta, A, B$)๋ฅผ ์ด์ง ํ๋ผ๋ฏธํฐ($\bar{A}, \bar{B}$)๋ก ๋ณ๊ฒฝํ๋ค. ์ด๋ ๋ณํ ๋ฐฉ๋ฒ์ zero-order hold(ZOH)์ ๋ฐ๋ฅธ๋ค.
์ด์ฐํ๋ ํด์๋ ๋ถ๋ณ๊ณผ ๊ฐ์ ์ถ๊ฐ์ ์ธ ์์ฑ์ ๋ถ์ฌํ ์ ์๊ณ , ๋ชจ๋ธ์ด ์ ์ ํ๊ฒ ์ ๊ทํ๋์๋์ง ์๋์ผ๋ก ํ์ธํ๋ ๋ฑ ์ฐ์ ์๊ฐ ์์คํ ๊ณผ ๊ธด๋ฐํ ๊ด๊ณ๊ฐ ์๋ค. ๋ํ RNN์ gating ๋ฉ์ปค๋์ฆ๊ณผ๋ ์ฐ๊ด์ด ์๋ค. ๊ทธ๋ฌ๋ ๊ธฐ๊ณ์ ๊ด์ ์์ ์ด์ฐํ๋ ๋จ์ํ ssm์ forward pass์ ๊ทธ๋ํ ๊ณ์ฐ์ ์ฒซ ๋จ๊ณ๋ก ๋ณด์ผ ์ ์๋ค. SSM์ ๋์ฒด ๋ฒ์ ์ ์ด์ฐํ ๋จ๊ณ ๋์ ํ๋ผ๋ฏธํฐ ($\bar{A}, \bar{B}$)๋ฅผ ์ง์ ๋์ฒดํ ์ ์๋๋ฐ, ์ด๋ ๋ ์ฌ์ด ๋ฐฉ๋ฒ์ด ๋ ์๋ ์๋ค.
Computation
ํ๋ผ๋ฏธํฐ๋ฅผ ์ด์ง ํ๋ผ๋ฏธํฐ๋ก ๋ณํํ ๋ค์, ๋ชจ๋ธ์ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ๊ณ์ฐ๋ ์ ์๋๋ฐ ๋ฐ๋ก Linear recurrence*์ *global convolution ์ด๋ค.
๋์ฒด๋ก ๋ชจ๋ธ์ ํจ์จ์ ์ธ ๋ณ๋ ฌ ํ์ต์ ์ํด convolutional mode๋ฅผ ์ฌ์ฉํ๋ค๊ฐ ํจ์จ์ ์ธ ํ๊ท์ถ๋ก ์ ์ํด recurrent ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
์ ํ์๋ถ๋ณ Linear Time Invariance(LTI)
์ฌ๊ธฐ์์๋ $(\Delta, A, B, C)$์ $(\bar{A}, \bar{B})$๊ฐ ๋ชจ๋ time steps์ ๊ณ ์ ๋๋ค. ์ด๋ฌํ ํน์ฑ์ Linear Time Invariance๋ผ๊ณ ํ๋ค. ์ด๋ recurrence์ convolution์ ๊น์ ์ฐ๊ด์ด ์๋ค. ํธํ๊ฒ LTI SSMs๋ฅผ ์ด๋ค ์ ํ ํ๊ท๋ ํฉ์ฑ๊ณฑ๊ณผ ๋๋ฑํ๊ฒ ์๊ฐํ ์ ์๊ณ , LTI๋ฅผ ์ด๋ฐ ์ข ๋ฅ์ ๋ชจ๋ธ ํด๋์ค์ ๋ํ ํฌ๊ด์ ์ฉ์ด๋ผ๊ณ ์๊ฐํด๋ ๋๊ฒ ๋ค.
๊ทผ๋ณธ์ ์ธ ํจ์จ์ ์ ์ฝ์ผ๋ก ์ธํด ๋ชจ๋ SSMs ๊ตฌ์กฐ๋ LTI์ด๋ค. ๊ทธ๋ฌ๋ ์ค์ํ ์์ฌ์ ์ LTI ๋ชจ๋ธ๋ค์ด ํน์ ํ ํ์ ์ ๋ฐ์ดํฐ์ ๋ํ ๋ชจ๋ธ๋ง์ ๋ํ ๊ทผ๋ณธ์ ์ธ ํ๊ณ๊ฐ ์๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋ฆฌ๊ณ ์ฐ๋ฆฌ์ ๊ธฐ์ ์ ๊ธฐ์ฌ๋ LTI ์ ์ฝ์ ์ญ์ ํ๊ณ ํจ์จ์ ์ธ ๋ณ๋ชฉํ์์ ๊ทน๋ณต์ ๊ด๋ จ์ด ์๋ค...
Structure and Dimensions
๊ฐ์ฅ ์ธ๊ธฐ์๋ structure ํํ๋ diagonal์ด๋ค. ์ด ๊ฒฝ์ฐ $A \in \mathbb{R}^{N \times N}, B \in \mathbb{R}^{N \times 1}, C \in \mathbb{R}^{1 \times N}$ ํ๋ ฌ๋ค์ ๋ชจ๋ N ์ซ์๋ก ํํ๋ ์ ์๋ค. Batch size B, Length L, channel D์ input sequence x๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด์๋ SSM์ด ๊ฐ channel์ ๋ ๋ฆฝ์ ์ผ๋ก ์ ์ฉ๋๋ค. ์ด ๊ฒฝ์ฐ์๋ ์ ์ฒด hidden state๋ ์ ๋ ฅ๋ง๋ค $DN$ ์ฐจ์์ ๊ฐ์ง๊ณ , ๋ชจ๋ sequence ๊ธธ์ด์ ๋ฐ๋ฅธ ๊ณ์ฐ์ $O(BLDN)$ ์๊ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ๋ค.
General State Space Models
- Markov Decision Process(MDP)
- Dynamic Casual Modeling(DCM)
- Kalman Filters
- Hidden Markov Models(HMM)
- Linear Dynamical Systems(LDS)
- Recurrent models at large deep learning
SSM Architectures
- Linear attention
- H3
- Hyena
- RetNet
- RWKV