Read `RWKV: Reinventing RNNs for the Transformer Era`

The icon is from https://wiki.rwkv.com/; all rights reserved to its owner.

RWKV: Reinventing RNNs for the Transformer Era

Paper link: https://arxiv.org/abs/2305.13048
Homepage: https://wiki.rwkv.com/
The homepage links to many related projects.

TL;DR

Transformers suffer from memory and computational complexity that scale quadratically with sequence length. Recurrent neural networks (RNNs) scale linearly in memory and computation, but their recurrent mechanism prevents parallelization during training and limits scalability. The proposed Receptance Weighted Key Value (RWKV) architecture aims to combine the efficient parallelizable training and scalability of Transformers with the efficient inference of RNNs. Its motivation, implementation, and application are predominantly in large language models (LLMs). A community around the project has trained several LLMs, and those models are publicly available; more information can be found on their homepage. I think this is a good trend, as it makes LLM techniques accessible to more people.

The core design of RWKV relies on equation (14) of the paper, which can be expressed recursively. Inference can take advantage of this recursive form and produce the next state from the current state (a property of RNNs), without recomputing each state from scratch.
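To check the claim that equation (14) can be computed recursively, here is a minimal single-channel sketch (RWKV treats channels independently, so one channel is enough). It evaluates equation (14) directly, with past tokens weighted by \(e^{-(t-1-i)w + k_{i}}\) and the current token by \(e^{u + k_{t}}\), and then again with a constant-size state \((a, b)\) that is decayed by \(e^{-w}\) at every step. The function names and toy inputs are my own, and the sketch omits the numerical-stability tricks a real implementation would need.

```python
import math

def wkv_parallel(w, u, ks, vs):
    """Evaluate eq. (14) directly for one channel: O(T^2) work overall."""
    out = []
    for t in range(len(ks)):
        # The current token gets the bonus u instead of a decay.
        num = math.exp(u + ks[t]) * vs[t]
        den = math.exp(u + ks[t])
        # Past tokens i < t are discounted by exp(-(t - 1 - i) * w).
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + ks[i])
            num += weight * vs[i]
            den += weight
        out.append(num / den)
    return out

def wkv_recurrent(w, u, ks, vs):
    """Same quantity with a constant-size state (a, b): O(T) work overall."""
    a, b = 0.0, 0.0  # running numerator / denominator over past tokens
    decay = math.exp(-w)
    out = []
    for k, v in zip(ks, vs):
        bonus = math.exp(u + k)
        out.append((a + bonus * v) / (b + bonus))
        # Fold the current token into the state and decay the older ones.
        a = decay * a + math.exp(k) * v
        b = decay * b + math.exp(k)
    return out

ks, vs = [0.1, -0.3, 0.5, 0.2], [1.0, 2.0, -1.0, 0.5]
print(wkv_parallel(0.7, 0.3, ks, vs))
print(wkv_recurrent(0.7, 0.3, ks, vs))  # same values up to rounding
```

The recurrent version touches each token once and only keeps \((a, b)\) between steps, which is what makes RNN-style inference cheap.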

The interpolation equations (11)-(13) are not very intuitive to me; the authors could have explained them further. They make the current \(\mathbf{r}\), \(\mathbf{k}\), and \(\mathbf{v}\) aware of both the current and the previous input. A natural question is why two steps are mixed here at all: is this necessary, given that equation (14) already has a mechanism to mix different time steps?

This is an interesting paper, written as a collaboration among a large number of authors. The techniques are evolving very fast, and the line between research and engineering in this field is increasingly blurred. This co-evolution shortens the path from ideas to products.

D1

Quadratic complexity

In the self-attention mechanism, the keys \(\mathbf{k}\), queries \(\mathbf{q}\), and values \(\mathbf{v}\) are sequences of length \(\mathrm{T}\). The attention operator can be written as:

\begin{equation} \label{eq:vanilla-self-attn} \mathrm{Attn(Q, K, V)}_{t}= \frac{\sum_{i=1}^{T} e^{q_{t}^{\intercal} k_{i}}v_{i}}{\sum_{i=1}^{T} e^{q_{t}^{\intercal} k_{i}}} \end{equation}
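A literal translation of this equation into Python makes the quadratic cost visible: each of the \(T\) outputs needs \(T\) scores. This is a hypothetical illustration that follows the equation exactly (no \(1/\sqrt{d}\) scaling, no causal mask, no max-subtraction for numerical stability), with made-up inputs.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def self_attention(qs, ks, vs):
    """Literal version of the equation: non-causal, unscaled, no stabilization."""
    T = len(qs)
    dim = len(vs[0])
    out = []
    for t in range(T):
        # T scores for this query; over all t this is T * T scores (quadratic).
        scores = [math.exp(dot(qs[t], ks[i])) for i in range(T)]
        z = sum(scores)
        out.append([sum(scores[i] * vs[i][c] for i in range(T)) / z for c in range(dim)])
    return out

qs = [[0.1, 0.2], [0.0, -0.3], [0.5, 0.5]]
ks = [[0.3, -0.1], [0.2, 0.2], [-0.4, 0.1]]
vs = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(self_attention(qs, ks, vs))
```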

Following the Attention Free Transformer (AFT), the paper uses an alternative formulation:

\begin{equation} \label{eq:aft-attn} \mathrm{Attn^{+}(W, K, V)}_{t}= \frac{\sum_{i=1}^{T} e^{w_{t,i}+ k_{i}}v_{i}}{\sum_{i=1}^{T} e^{w_{t,i}+ k_{i}}}, \end{equation}

where \(W \in R^{T \times T}\) is a matrix of pair-wise offsets (position biases) learned during training, and each \(w_{t,i}\) is a scalar.
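A sketch of this AFT-style operator as a literal per-channel translation, with hypothetical names and made-up values. As in the equation here, the sigmoid query gate of the original AFT is left out. The point to notice is that the offset matrix has one learned entry per pair \((t, i)\), i.e. \(T^2\) parameters, and no query-key dot product appears anywhere.

```python
import math

def aft_attention(W, ks, vs):
    """Per channel c: sum_i exp(W[t][i] + k_i[c]) v_i[c] / sum_i exp(W[t][i] + k_i[c])."""
    T, dim = len(ks), len(ks[0])
    out = []
    for t in range(T):
        row = []
        for c in range(dim):
            weights = [math.exp(W[t][i] + ks[i][c]) for i in range(T)]
            num = sum(weights[i] * vs[i][c] for i in range(T))
            row.append(num / sum(weights))
        out.append(row)
    return out

T = 3
W = [[0.0, -1.0, -2.0], [0.5, 0.0, -1.0], [-0.5, 0.2, 0.0]]  # hypothetical learned T x T offsets
ks = [[0.3, -0.1], [0.2, 0.2], [-0.4, 0.1]]
vs = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(aft_attention(W, ks, vs))
print("offset parameters:", T * T)
```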

RWKV simplifies this. Instead of learning the offset matrix of AFT, it learns a vector of \(d\) dimensions, where \(d\) is the number of channels (I suppose they mean the dimension of the input vector of RWKV). This vector is multiplied by the relative position, so the offset is expressed as

\begin{equation} \label{eq:rwkv-attn-offset} w_{t,i}=-(t-i)w, \end{equation}

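where \(w \in R_{\geqslant 0}^{d}\). The learned offset parameter is therefore a vector, rather than a \(T \times T\) matrix as in AFT, and it is applied channel-wise. Since \(w\) is non-negative, \(e^{w_{t,i}} \leqslant 1\) for \(i \leqslant t\), so the weight of a past token decays with its distance from \(t\).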
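A small sketch of how the RWKV offset replaces AFT's full matrix, assuming a causal mask (only positions \(i \leqslant t\) contribute, as in the RWKV recurrence). The decay values in `w` are made up for illustration; the function name is hypothetical.

```python
import math

def rwkv_offsets(w, T):
    """Per-channel offsets w_{t,i} = -(t - i) * w[c] for positions i <= t.

    Returns offsets[c][t][i]; future positions (i > t) are None (causal mask assumed).
    """
    return [
        [[-(t - i) * wc if i <= t else None for i in range(T)] for t in range(T)]
        for wc in w
    ]

w = [0.0, 0.5, 2.0]  # hypothetical learned decays, one per channel (d = 3)
T = 4
offsets = rwkv_offsets(w, T)
# Learned offset parameters: d for RWKV versus T * T for AFT's full matrix.
print(len(w), T * T)  # 3 16
# Factor exp(w_{t,i}) seen by the last position, per channel: larger w forgets faster.
for c in range(len(w)):
    print([round(math.exp(o), 3) for o in offsets[c][T - 1]])
```

Channel 0 (\(w = 0\)) weights all past tokens equally, while channel 2 effectively only looks at the most recent ones.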

Background

The paper reviews RNNs, especially LSTMs.

RWKV seems to be inspired by the Attention Free Transformer (AFT), in which pair-wise position biases for the attention weights are learned. In this way, AFT skips the query-key dot product, so the attention weights do not depend on the query (in AFT the query only acts as an element-wise gate on the output).

RWKV uses the same mechanism as AFT, but instead of a full matrix of learned pair-wise biases, it uses a time-decay bias.


Some background on AFT: AFT still has query, key, and value matrices. Instead of the dot product between queries and keys, it learns an offset matrix and adds a scalar offset to the keys to form the attention weights. The computational advantage therefore comes from removing the dot product, which does not seem that appealing to me.


D2

Time mixing and channel mixing

Time mixing is a core claim of RWKV, but after D1 it was still not clear to me what it is, how it is achieved, and what the reasoning behind it is.

The recurrence is formulated both as a linear interpolation between the current input and the input at the previous time step (a technique we refer to as time-shift mixing or token shift, indicated by the diagonal lines in Fig. 3)

This suggests that \(\mathbf{r}\), \(\mathbf{k}\), and \(\mathbf{v}\) are all interpolations of the inputs at the current and previous time steps. But no state is involved in this interpolation itself.
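Token shift on its own is just a per-channel interpolation. The sketch below assumes that the input before the first token is a zero vector and uses a hypothetical mixing vector `mu` (the \(\mu\) of equations (11)-(13)); the same function would be applied with separate \(\mu_{r}\), \(\mu_{k}\), \(\mu_{v}\) before the projections \(W_{r}\), \(W_{k}\), \(W_{v}\). The only thing carried from one step to the next is the previous input \(x_{t-1}\).

```python
def token_shift(xs, mu):
    """Return mu * x_t + (1 - mu) * x_{t-1} for each step, channel by channel."""
    prev = [0.0] * len(mu)  # assumption: zero vector before the first token
    mixed = []
    for x in xs:
        mixed.append([m * cur + (1.0 - m) * old for m, cur, old in zip(mu, x, prev)])
        prev = x  # the only thing carried forward is the previous input
    return mixed

xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# Channel 0 keeps only the current token, channel 1 only the previous one.
print(token_shift(xs, [1.0, 0.0]))  # [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
```

Values of `mu` between 0 and 1 give each channel its own blend of "now" and "one step ago".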

Figure in paper

D3

The time-mixing block is given by equations (11)-(15) in the paper.
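For reference, equations (11)-(15) as I read them in the paper (worth double-checking against the original), where \(\sigma\) is the sigmoid, \(\odot\) is the element-wise product, and \(\mu_{r}\), \(\mu_{k}\), \(\mu_{v}\) are learned token-shift mixing weights:

\begin{equation} r_{t} = W_{r} \cdot (\mu_{r} \odot x_{t} + (1-\mu_{r}) \odot x_{t-1}) \end{equation}
\begin{equation} k_{t} = W_{k} \cdot (\mu_{k} \odot x_{t} + (1-\mu_{k}) \odot x_{t-1}) \end{equation}
\begin{equation} v_{t} = W_{v} \cdot (\mu_{v} \odot x_{t} + (1-\mu_{v}) \odot x_{t-1}) \end{equation}
\begin{equation} wkv_{t} = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_{i}} v_{i} + e^{u + k_{t}} v_{t}}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_{i}} + e^{u + k_{t}}} \end{equation}
\begin{equation} o_{t} = W_{o} \cdot (\sigma(r_{t}) \odot wkv_{t}) \end{equation}

Note that the past tokens use the exponent \(-(t-1-i)w\), and the current token is handled separately through \(u\) rather than through the decay.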

Here \(wkv_{t}\) is a weighted sum, similar to the attention mechanism. Each receptance \(r_{t}\), key \(k_{t}\), and value \(v_{t}\) is an interpolation between the current and previous time steps.

Questions:

  1. What is \(u\) in equation (14)?
  2. Why are equations (16)-(18) called channel mixing?
  3. What is the intuition behind time mixing followed by channel mixing?
  4. How can they run recursively during inference?
The paper answers Q4 with an interesting perspective: the time-mixing block can be viewed as an RNN cell.
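Putting token shift and the recurrent form of equation (14) together, a time-mixing block can be run one token at a time with a constant-size state. The sketch below is my own minimal, pure-Python reading of this RNN-cell view, not the authors' implementation: the state is assumed to be the previous input plus the per-channel numerator and denominator of equation (14), all parameter names and values are hypothetical, and numerical stabilization is omitted.

```python
import math

def matvec(W, x):
    """Matrix-vector product for lists of lists."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def time_mixing_step(x, state, p):
    """Run one token through a time-mixing block treated as an RNN cell.

    state = (x_prev, a, b): previous input, plus per-channel running numerator
    and denominator of eq. (14). p holds hypothetical parameters.
    """
    x_prev, a, b = state

    def shift(mu):
        # Token shift: interpolate current and previous input per channel.
        return [m * xc + (1.0 - m) * pc for m, xc, pc in zip(mu, x, x_prev)]

    r = matvec(p["W_r"], shift(p["mu_r"]))
    k = matvec(p["W_k"], shift(p["mu_k"]))
    v = matvec(p["W_v"], shift(p["mu_v"]))

    wkv, new_a, new_b = [], [], []
    for c in range(len(x)):
        bonus = math.exp(p["u"][c] + k[c])  # the current token uses u
        wkv.append((a[c] + bonus * v[c]) / (b[c] + bonus))
        decay = math.exp(-p["w"][c])  # past tokens decay by exp(-w) per step
        new_a.append(decay * a[c] + math.exp(k[c]) * v[c])
        new_b.append(decay * b[c] + math.exp(k[c]))

    # Receptance gate and output projection, eq. (15).
    out = matvec(p["W_o"], [sigmoid(rc) * wc for rc, wc in zip(r, wkv)])
    return out, (list(x), new_a, new_b)

d = 2
eye = [[1.0, 0.0], [0.0, 1.0]]
params = {
    "W_r": eye, "W_k": eye, "W_v": eye, "W_o": eye,
    "mu_r": [0.5, 0.5], "mu_k": [0.5, 0.5], "mu_v": [0.5, 0.5],
    "w": [0.5, 1.0], "u": [0.0, 0.0],
}
state = ([0.0] * d, [0.0] * d, [0.0] * d)  # assumption: zero initial state
for x in [[1.0, -2.0], [0.5, 0.3], [-1.0, 1.0]]:
    y, state = time_mixing_step(x, state, params)
    print(y)
```

The state never grows with the sequence: whatever the length, each step only reads and writes three vectors of size \(d\).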


D5

Questions:

D6

Trying some code.

The following two posts are good places to start: https://ben.bolte.cc/rwkv-model and https://johanwind.github.io/2023/03/23/rwkv_details.html.

The code from Ben’s blog is on his GitHub page: https://github.com/codekansas/rwkv.