3. Selective State Space Models
3.1 Selection as a Means of Compression
๋ณํฉ ์์ ์ ๊ดํ ๋๊ฐ์ง ์คํ ์์
- Selective Copying : ๊ธฐ์ตํ ํ ํฐ์ ์์น๋ฅผ ๋ฐ๊ฟ Copying Task๋ฅผ ์์ ํ๋ค. ๊ด๋ จ์๋ ํ ํฐ์ ๊ธฐ์ตํ๊ณ ๊ด๋ จ์๋ ํ ํฐ์ ๊ฑธ๋ฌ๋ด๋ ค๋ฉด ๋ด์ฉ ์ธ์ ์ถ๋ก (content-aware resoning)์ด ํ์ํ๋ค.
- Induction Heads : ์ ์ ํ ์ปจํ ์คํธ์์ ์ถ๋ ฅ์ ๋ผ ์๊ธฐ๋ฅผ ์๊ธฐ ์ํด์๋ ๋ด์ฉ ์ธ์ ์ถ๋ก ์ด ํ์ํ๋ค. LLM์ ๋์ ๊ณผ์ ์ค๋ช ์ ์ํด ๊ฐ์ฅ ๋ง์ด ์ฐ์ด๋ ๋งค์ปค๋์ฆ.
์ด ์์
์ LTI ๋ชจ๋ธ์ ์คํจํ ๋ชจ๋๋ฅผ ๋ณด์ฌ์ค๋ค. ํ๊ท์ ๊ด์ ์์ constant dynamics(์ฌ๊ธฐ์์๋ $\bar{A}, \bar{B}$)๋ context์์ ์ฌ๋ฐ๋ฅธ ์ ๋ณด๋ฅผ ์ ํํ๊ฑฐ๋, ์
๋ ฅ์ ๋ฐ๋ฅธ ๋ฐฉ์์ผ๋ก ์ํ์ค๋ฅผ ๋ฐ๋ผ hidden state์ ์ํฅ์ ์ค ์ ์๋ค.
ํฉ์ฑ๊ณฑ์ ๊ด์ ์์ global convolutions๋ ์ผ๋ฐ copying task๋ฅผ ํด๊ฒฐํ ์ ์๋ค๊ณ ์๋ ค์ ธ ์๋๋ฐ, ์๋ํ๋ฉด ๊ทธ๊ฒ ์๊ฐ ์ธ์ ๊ธฐ๋ฐ์ด๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋ฌ๋ ๋ด์ฉ ์ธ์ ๋ถ์กฑ์ ๋ฌธ์ ๊ฐ ์์ด์ Selective Copying task์๋ ์ด๋ ค์์ด ์๋ค.
์ํ์ค ๋ชจ๋ธ์ ํจ์จ์ฑ๊ณผ ํจ๊ณผ์ฑ(์ผ๋ง๋ ํจ์จ์ ์ผ๋ก ๋ชฉํ๋ฅผ ๋ฌ์ฑํ๋๊ฐ vs ๋ชฉํ๋ฅผ ์ผ๋ง๋ ์ฑ๊ณต์ ์ผ๋ก ๋ฌ์ฑํ๋๊ฐ)์ ํธ๋ ์ด๋ ์คํ๋ ์ผ๋ง๋ ์์ ์ state๋ฅผ ์ ์์ถํ๋๊ฐ๋ก ํํ๋๋ค. ํจ์จ์ ์ธ ๋ชจ๋ธ์ ๋ฐ๋์ ์์ state๋ฅผ ๊ฐ์ ธ์ผํ๊ณ , ๋ฐ๋ฉด์ ํจ๊ณผ์ ์ธ ๋ชจ๋ธ์ context์์ ๋ชจ๋ ํ์ํ ์ ๋ณด๋ฅผ state์ ํฌํจํด์ผํ๋ค.
์ด ๋ ผ๋ฌธ์์ ์ํ์ค ๋ชจ๋ธ ์๋ฆฝ์ ํ์ํ ๊ธฐ๋ณธ ์์น์ ์ง์ค์ ์ํ ๋ด์ฉ์ธ์ ํน์ ์ ๋ ฅ์ sequential state๋ก ํํฐ๋งํ ์ง์ ๋ํ ์ ํ์ฑ Selectivity๋ก ์ ์ํ๋ค. ๋ถ๋ถ์ ์ผ๋ก selection mechanism์ ์ ๋ณด ์ ํ ํน์ ์ํ์ค ์ฐจ์์ ๋ํ ์ํธ์์ฉ์ ์ ์ดํ๋ค.
3.2 Improving SSMs with Selection
Selection ๋งค์ปค๋์ฆ์ ๋ชจ๋ธ์ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ํ์ค๋ฅผ ๋ฐ๋ผ ์ํธ์์ฉ์ ์ํฅ์ ๋ฏธ์น๋ ํ๋ผ๋ฏธํฐ๋ค์ ์ ๋ ฅ์ ๋ฐ๋ผ ์์ง์ด๊ฒ ํ๋ ๊ฒ์ด๋ค.
๊ฐ์ฅ ๋ค๋ฅธ ์ ์ SSM๊ณผ ๋ฌ๋ฆฌ $\Delta, B, C$ ํ๋ผ๋ฏธํฐ๋ฅผ input์ ๋ํด ํจ์๋ก ์ ์ํ๊ณ , tensor์ shape์ ๋ฐ๊พผ ๊ฒ์ด๋ค. ๋ถ๋ถ์ ์ผ๋ก๋ ์ด ํ๋ผ๋ฏธํฐ๋ค์ด ์ด์ ๊ธธ์ด ์ฐจ์ L์ ๊ฐ์ง๋ค๋ ๊ฒ์ธ๋ฐ, ์ด๋ ๋ชจ๋ธ์ด ์๋ถ๋ณ์์ ์๊ฐ ๋ณํ๋ฅผ ๋ฐ์ํ๋ ์ชฝ์ผ๋ก ๋ฐ๋์๋ค๋ ๋ป์ด๋ค.
$$\begin{matrix}
s_{B}(x) &=& Linear_{N}(x) \
s_{C}(x) &=& Linear_{N}(x) \
s_{\Delta}(x) &=& Broadcast_{D}(Linear_{1}(x) \
\tau_{\Delta} &=& softplus
\end{matrix}$$
$Linear_{d}$ ๋ ํ๋ผ๋ฏธํฐํ ๋ ์ฐจ์ $d$ ์ projection์ด๋ค. $s_{\Delta}$ ์ $\tau_{\Delta}$ ์ ์ ํ์ RNN gating ๋งค์ปค๋์ฆ๊ณผ์ ์ฐ๊ฒฐ ๋๋ฌธ์ด๋ค.
3.3 Efficient Implementation of Selective SSMs
์ ํ ๋งค์ปค๋์ฆ์ ๊ฝค ์์ฐ์ค๋ฝ๊ณ , ์ด์ ์ฐ๊ตฌ์์๋ recurrent SSMs์์ $\Delta$๋ฅผ ์๊ฐ์ด ์ง๋ ์๋ก ๋ณํ์ํค๋ ๊ฒ๊ณผ ๊ฐ์ ์ ํ์ ํน๋ณํ ์ผ์ด์ค๋ฅผ ํตํฉํ๋ ค๊ณ ๋ ธ๋ ฅํ๋ค. ๊ทธ๋ฌ๋ SSMs ์ฌ์ฉ์ ์ฃผ์ ํ๊ณ๋ ์ฐ์์ ํจ์จ์ฑ์ ์๋ค. ์ด๋ S4๋ ๋ค๋ฅธ ํ์ ์๊ณ ๋ฆฌ์ฆ์ด LTI(non-selective) ๋ชจ๋ธ์, ๊ทธ ์ค์์๋ global convolution ํํ์ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ์ด์ ์ด๋ค.
3.3.1 Motivation of Prior Models
- SSMs์ ๊ฐ์ recurrent model์ ํํ๋ ฅ๊ณผ ์๋ ์ฌ์ด์ ๊ท ํ์ ์ก์์ผํ๋ค.
- recurrent ๋ชจ๋์์ ํ์๋ ๊ฒ์ด convolution ๋ชจ๋๋ผ์ recurrent๊ฐ ์ข ๋ ์ ์ฐํ๋ค. ๊ทธ๋ฌ๋ ์ด ๊ณผ์ ์์ (B, L, D, C) ์ฐจ์์ latent state h๋ฅผ ๊ณ์ฐํ๊ณ ๊ตฌ์ฒดํํด์ผํ๋๋ฐ, ์ด๊ฒ ์ ๋ ฅ x์ ์ถ๋ ฅ y์ ์ฐจ์๋ณด๋ค ํจ์ฌ ํฌ๋ค. ๊ทธ๋์ state ์ฐ์ฐ์ ์ฐํํ๊ณ convolution kernel(B, L, D)๋ฅผ ๊ตฌ์ฒดํํ๋ convolution mode๊ฐ ๋ ํจํผ์ ์ด๋ผ๊ณ ์๊ฐ๋๋ค.
- ์ด์ LTI state space model ์ ์ด์ค ์ํ-ํฉ์ฑ๊ณฑ ํํ๋ฅผ ์ฌ์ฉํด์ ํจ์จ์ฑ ์ ํ ์์ด ๊ธฐ์กด RNN๋ณด๋ค 10-100๋ฐฐ์ ๋ฌํ๋ state ์ฐจ์์ ๋๋ฆฐ๋ค.