Transformers have revolutionized natural language processing and continue to push the boundaries of what's possible in AI. However, as these models grow in size and capability, they face increasing challenges in handling long sequences during training. The memory demands for processing extended contexts can quickly overwhelm even high-end GPUs, limiting our ability to train on lengthy inputs that are critical for tasks like long-form document understanding, extended conversations, and multimodal reasoning.
To grasp where these challenges come from, let's simplify the architecture of a transformer model like Llama. Essentially, Llama is built from a stack of identical units called transformer blocks. By focusing on a single transformer block, we can understand the computations involved throughout the entire model.
The diagram above illustrates the overall architecture of Llama and a single transformer block. A transformer block consists of three main components, which we examine below through the tensors they produce.
For simplicity, we'll ignore some additional components like Layer Normalization, masking, dropout, and residual connections. Instead, we'll focus on the tensor shapes and sizes involved in the computations, which are crucial for understanding the model's memory usage and computational requirements.
Different versions of Llama have varying sizes and capabilities. Here's a table summarizing some key parameters:
| Feature | LLAMA2-7B | LLAMA3-8B | LLAMA3.1-8B |
|---|---|---|---|
| Layers | 32 | 32 | 32 |
| Hidden Size (d) | 4096 | 4096 | 4096 |
| Sequence Length (s) | 4096 | 8192 | 131072 |
| MLP Intermediate Size (I) | 11008 | 14336 | 14336 |
| Vocabulary Size (V) | 32000 | 128256 | 128256 |
Observations:

- The MLP intermediate size (I) has increased from 11,008 to 14,336. This allows the model to capture more complex patterns but increases computational demands.
- The vocabulary size (V) has grown from 32,000 to 128,256, accommodating more words and subwords but also increasing the size of certain tensors.
- The sequence length (s) has expanded dramatically, up to 131,072 tokens. While this enables the model to process longer inputs, it significantly increases the computational load.

Understanding the shapes of tensors (multi-dimensional arrays) in the model helps us identify where computational challenges arise.
Key dimensions:

- Batch size (b): Number of sequences processed in parallel (e.g., b = 1).
- Sequence length (s): Length of each input sequence (e.g., s = 8192 or 131072).
- Hidden size (d): Size of the hidden state vector in the model (e.g., d = 4096).
- MLP intermediate size (I): Size of the intermediate vector in the MLP (e.g., I = 14336).
- Vocabulary size (V): Number of unique tokens the model can predict (e.g., V = 128256).

The key tensors and their shapes are:
- Input hidden states (X): Shape (b, s, d).
- Query (Q): Result of X projected by W_Q, shape (b, s, d).
- Key (K): Result of X projected by W_K, shape (b, s, d).
- Value (V): Result of X projected by W_V, shape (b, s, d).
- MLP up projection (XW_up): Shape (b, s, I).
- MLP gate projection (XW_gate): Shape (b, s, I).
- MLP output (after the down projection): Shape (b, s, d).
- Logits (logit): Result of X projected by W_LM-HEAD, shape (b, s, V).
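To make these shapes concrete, here is a minimal PyTorch sketch that traces them using the Llama-3.1-8B-like dimensions from the table. It runs on the meta device, so nothing is actually allocated; the weight names (W_q, W_up, W_lm) are illustrative placeholders, not taken from any real implementation.

```python
import torch

# Llama-3.1-8B-like dimensions from the table above.
b, s, d, I, V = 1, 131072, 4096, 14336, 128256

# "meta" tensors carry only shape and dtype, so no memory is allocated --
# the point is purely to trace the shapes of the large activations.
X = torch.empty(b, s, d, device="meta", dtype=torch.bfloat16)   # hidden states (b, s, d)

W_q = torch.empty(d, d, device="meta", dtype=torch.bfloat16)    # one attention projection
Q = X @ W_q                                                     # (b, s, d); K and V are analogous

W_up = torch.empty(d, I, device="meta", dtype=torch.bfloat16)   # MLP up projection (gate is analogous)
up = X @ W_up                                                   # (b, s, I)

W_lm = torch.empty(d, V, device="meta", dtype=torch.bfloat16)   # LM-Head projection
logits = X @ W_lm                                               # (b, s, V)

print(Q.shape, up.shape, logits.shape)
# torch.Size([1, 131072, 4096]) torch.Size([1, 131072, 14336]) torch.Size([1, 131072, 128256])
```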
The logits tensor has shape (b, s, V). With large s and V, this tensor becomes enormous: with s = 131072 and V = 128256, it requires over 33 GB of memory just to store the logits.

The MLP intermediates (XW_up and XW_gate) have shape (b, s, I) and also consume significant memory: with s = 131072 and I = 14336, they require over 7.5 GB just to store these intermediate activations.

X, Q, K, and V have shape (b, s, d). Since d is smaller than V (and usually I), these tensors are comparatively small, around 1 GB each.

Attention computes QK^T, producing a tensor of shape (b, s, s, h), where h is the number of attention heads. For example, with a sequence length of 131,072 tokens and 16 heads, this tensor would have dimensions (1, 131072, 131072, 16), which requires about 550 GB. Fortunately, memory-efficient attention kernels such as FlashAttention never materialize it, keeping only Q, K, and V with shape (b, s, d), around 1 GB each. That opens the door to optimizing the other large tensors.
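As a sanity check on those figures, here is a quick back-of-the-envelope calculation. It assumes 2 bytes per element (e.g., bf16), which matches the numbers quoted above.

```python
# Activation sizes for Llama-3.1-8B-style dimensions, assuming bf16 (2 bytes/element).
b, s, d, I, V, h = 1, 131072, 4096, 14336, 128256, 16
BYTES, GB = 2, 1e9

logits      = b * s * V * BYTES / GB           # (b, s, V)
mlp_inter   = 2 * b * s * I * BYTES / GB       # XW_up and XW_gate, each (b, s, I)
hidden      = b * s * d * BYTES / GB           # one of X, Q, K, V: (b, s, d)
attn_scores = b * s * s * h * BYTES / GB       # QK^T scores, if materialized

print(f"logits: {logits:.1f} GB")                # ~33.6 GB
print(f"MLP intermediates: {mlp_inter:.1f} GB")  # ~7.5 GB
print(f"one (b, s, d) tensor: {hidden:.2f} GB")  # ~1.1 GB
print(f"attention scores: {attn_scores:.0f} GB") # ~550 GB
```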
To handle the computational challenges posed by large tensor sizes, especially with long sequences and large vocabularies, we can adopt the following strategy:

- Partition the input sequence into M smaller chunks, or mini-sequences, of length s/M.
- The MLP intermediates shrink from (b, s, I) to (b, s/M, I).
- The logits shrink from (b, s, V) to (b, s/M, V).
- By choosing a suitable M, we reduce the memory required for these tensors by a factor of M, making the intermediate memory usage manageable.

Thankfully, this idea closely mirrors mini-batch training, which is already deployed as gradient accumulation in many libraries. We can apply the same technique at the sequence level and work out the necessary details.
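To make this concrete, here is a minimal sketch of the idea applied to the LM-Head: the sequence is split into M mini-sequences, each chunk's logits are computed (and recomputed during backward via checkpointing) so that only a (b, s/M, V) logits tensor is ever alive, and the per-chunk losses are accumulated much like gradient accumulation over mini-batches. The function names are my own; this is not the MST library's API.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def _chunk_loss(h, w_lm_head, y):
    # (b, s/M, d) @ (d, V) -> (b, s/M, V); return the summed loss over this chunk.
    logits = h @ w_lm_head
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), reduction="sum")

def mini_sequence_lm_loss(hidden, w_lm_head, labels, M=4):
    """LM loss without materializing the full (b, s, V) logits tensor.

    hidden:    (b, s, d) final hidden states
    w_lm_head: (d, V) LM-Head projection
    labels:    (b, s) target token ids
    M:         number of mini-sequences
    """
    total_loss, total_tokens = hidden.new_zeros(()), 0
    for h, y in zip(hidden.chunk(M, dim=1), labels.chunk(M, dim=1)):
        # checkpoint() recomputes this chunk's logits during backward, so peak
        # logits memory drops from (b, s, V) to (b, s/M, V).
        total_loss = total_loss + checkpoint(_chunk_loss, h, w_lm_head, y, use_reentrant=False)
        total_tokens += y.numel()
    return total_loss / total_tokens
```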
Congratulations! We now essentially know the key idea of the Mini-Sequence Transformers [NeurIPS'24] paper. Additionally, I recommend checking out the awesome Mini-Sequence Transformers repo.
Theorem: Sequence Extension
Mini-Sequence Transformers (MST) can reduce intermediate memory usage by a factor of M. Working together with gradient checkpointing, the total memory occupation becomes:
$$\text{Memory} = W_{\text{mem}} + S \times \left( \dfrac{I_{\text{mem}}}{M} + \sqrt{L} \times A_{\text{mem}} \right),$$
where:
- $W_{\text{mem}}$ is the memory occupied by the model weights.
- $S$ is the sequence length.
- $I_{\text{mem}}$ is the memory per token for intermediate computations in the MLP and LM-Head.
- $L$ is the number of layers.
- $A_{\text{mem}}$ is the memory per token per layer for attention.
For a GPU with maximum memory $M_{\text{max}}$, the maximum sequence length is:
$$S_{\text{max}} = \dfrac{M_{\text{max}} - W_{\text{mem}}}{\dfrac{I_{\text{mem}}}{M} + \sqrt{L} \times A_{\text{mem}}},$$
Going a step further, when we combine MST with CPU offloading, the limit becomes:
$$S_{\text{max}} = \dfrac{M_{\text{max}} - W_{\text{mem}}}{\dfrac{I_{\text{mem}}}{M} + A_{\text{mem}}},$$
which is significantly longer than the standard implementation, where the maximum sequence length is:
$$S_{\text{max}} = \dfrac{M_{\text{max}} - W_{\text{mem}}}{I_{\text{mem}} + L \times A_{\text{mem}}}.$$
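To get a feel for what these formulas imply, here is a tiny comparison that plugs illustrative constants into the three expressions. Every number below is a made-up round figure chosen only for illustration, not a measurement.

```python
# Illustrative only: all constants are assumed, not measured.
M_max = 80e9        # 80 GB of GPU memory
W_mem = 16e9        # ~16 GB of weights (an 8B model in bf16)
I_mem = 300e3       # bytes per token for MLP + LM-Head intermediates (assumed)
A_mem = 10e3        # bytes per token per layer for attention (assumed)
L, M = 32, 8        # number of layers, number of mini-sequences

standard    = (M_max - W_mem) / (I_mem + L * A_mem)
mst_ckpt    = (M_max - W_mem) / (I_mem / M + L**0.5 * A_mem)
mst_offload = (M_max - W_mem) / (I_mem / M + A_mem)

print(f"standard:            {standard:12,.0f} tokens")
print(f"MST + checkpointing: {mst_ckpt:12,.0f} tokens")
print(f"MST + CPU offload:   {mst_offload:12,.0f} tokens")
```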
This demonstrates that by reducing the intermediate memory with MST, we can train models with much longer sequences within the same memory constraints.
We have developed methods and libraries to facilitate this approach. By installing the mini-sequence wrapper (`pip install -U https://github.com/wdlctc/mini-s`), you can efficiently train large transformer models with long sequences:
```python
from minis.mini_sequence import minisequence
model = minisequence(model)
```
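For context, here is a sketch of how the wrapper might slot into an ordinary fine-tuning step. Only the minisequence(model) call above comes from the library; the checkpoint name is a placeholder, and the rest is standard Hugging Face / PyTorch usage under the assumption that the wrapped model keeps the usual causal-LM forward signature.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minis.mini_sequence import minisequence

name = "meta-llama/Meta-Llama-3.1-8B"   # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16).cuda()
model = minisequence(model)             # wrap the model as shown above

# One toy training step on a long input (assumes the HF forward signature is preserved).
batch = tokenizer("a very long document ...", return_tensors="pt").to("cuda")
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()
```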
Understanding the tensor shapes and sizes in transformer models is crucial for identifying computational challenges and optimizing training. By focusing on how the dimensions of these tensors impact memory usage, especially with large sequence lengths and vocabulary sizes, we can develop strategies to make training more efficient.
Partitioning input sequences and processing them as mini-sequences is an effective way to reduce memory usage without sacrificing performance. This approach enables the training of large transformer models on long sequences using available hardware, paving the way for advancements in natural language understanding and generation tasks.
Note: For those interested in implementing these techniques, tools like the Mini-Sequence Transformers library are available and can be integrated into your projects to facilitate efficient training on long sequences.