Optimizing Intermediate Values in Transformers

Transformers have revolutionized natural language processing and continue to push the boundaries of what's possible in AI. However, as these models grow in size and capability, they face increasing challenges in handling long sequences during training. The memory demands for processing extended contexts can quickly overwhelm even high-end GPUs, limiting our ability to train on lengthy inputs that are critical for tasks like long-form document understanding, extended conversations, and multimodal reasoning.

Understanding the Transformer Architecture

To grasp where these challenges come from, let's simplify the architecture of a transformer model like Llama. Essentially, Llama is built from a stack of identical units called transformer blocks. By focusing on a single transformer block, we can understand the computations involved throughout the entire model.

[Figure: overall architecture of Llama and a single transformer block]

The diagram above illustrates the overall architecture of Llama and a single transformer block. Three components dominate the computation:

  1. Self-Attention: This mechanism allows the model to weigh the importance of different tokens (words or subwords) in the input sequence when processing each token.
  2. Multilayer Perceptron (MLP): A feedforward neural network that processes the output from the self-attention layer to capture complex patterns.
  3. Language Modeling Head (LM-Head): A final linear layer that maps the processed information to a probability distribution over the vocabulary, helping predict the next token in a sequence.

For simplicity, we'll ignore some additional components like Layer Normalization, masking, dropout, and residual connections. Instead, we'll focus on the tensor shapes and sizes involved in the computations, which are crucial for understanding the model's memory usage and computational requirements.

Key Parameters of Llama Models

Different versions of Llama have varying sizes and capabilities. Here's a table summarizing some key parameters:

| Feature | LLAMA2-7B | LLAMA3-8B | LLAMA3.1-8B |
|---|---|---|---|
| Layers | 32 | 32 | 32 |
| Hidden Size (d) | 4096 | 4096 | 4096 |
| Sequence Length (s) | 4096 | 8192 | 131072 |
| MLP Intermediate Size (I) | 11008 | 14336 | 14336 |
| Vocabulary Size (V) | 32000 | 128256 | 128256 |
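
If you want to verify these numbers yourself, they are exposed in the Hugging Face model configs. The snippet below is a small illustration using Llama3-8B; the checkpoint name is a placeholder and the meta-llama repositories are gated, so access approval and a login are required.

```python
from transformers import AutoConfig

# Gated repo: requires accepting the license and `huggingface-cli login`.
cfg = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")

print("Layers:               ", cfg.num_hidden_layers)        # 32
print("Hidden size (d):      ", cfg.hidden_size)              # 4096
print("Sequence length (s):  ", cfg.max_position_embeddings)  # 8192
print("MLP intermediate (I): ", cfg.intermediate_size)        # 14336
print("Vocabulary size (V):  ", cfg.vocab_size)               # 128256
```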

Observations:

  • The number of layers and the hidden size d are identical across all three models.
  • The supported sequence length s grows 32×, from 4096 in LLAMA2-7B to 131072 in LLAMA3.1-8B.
  • The vocabulary size V quadruples (32000 to 128256), and the MLP intermediate size I is roughly 3.5× the hidden size.

In other words, the dimensions that grow are exactly the ones that determine the size of the intermediate tensors we examine next.

Tensor Shapes and Their Impact on Computation

Understanding the shapes of tensors (multi-dimensional arrays) in the model helps us identify where computational challenges arise.

Scalars and Input Shapes

Throughout this post we use a handful of scalars: b is the batch size, s the sequence length, d the hidden size, I the MLP intermediate size, and V the vocabulary size (see the table above for concrete values). The input to a transformer block is a tensor of token embeddings of shape (b, s, d).

Key Tensors in the Transformer Block

  1. Input Tensor (X): Shape (b, s, d)
    • Represents the input embeddings for the sequence.
  2. Self-Attention Computations:
    • Queries (Q): Result of X projected by W_Q, shape (b, s, d)
    • Keys (K): Result of X projected by W_K, shape (b, s, d)
    • Values (V): Result of X projected by W_V, shape (b, s, d)
  3. MLP Computations:
    • Intermediate Tensors:
      • XW_up: Shape (b, s, I)
      • XW_gate: Shape (b, s, I)
    • These are combined and transformed back to shape (b, s, d).
  4. LM-Head Output:
    • Logits: Result of X projected by W_LM-Head, shape (b, s, V) (see the shape sketch below)
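
To make these shapes concrete, here is a minimal PyTorch sketch of the projections above. The layer names and toy dimensions are mine, chosen only for illustration; attention itself, normalization, and residual connections are omitted, as discussed earlier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy dimensions so the script runs anywhere; Llama3-8B uses
# d = 4096, I = 14336, V = 128256, and s up to 8192.
b, s, d, I, V = 2, 16, 64, 224, 512

X = torch.randn(b, s, d)                       # input embeddings: (b, s, d)

# Self-attention projections (head splitting omitted for clarity)
W_Q, W_K, W_V = (nn.Linear(d, d, bias=False) for _ in range(3))
Q, K, Vals = W_Q(X), W_K(X), W_V(X)            # each (b, s, d)

# MLP: gate/up projections to the intermediate size, then back down
W_up, W_gate = nn.Linear(d, I, bias=False), nn.Linear(d, I, bias=False)
W_down = nn.Linear(I, d, bias=False)
XW_up, XW_gate = W_up(X), W_gate(X)            # each (b, s, I), the large MLP intermediates
H = W_down(F.silu(XW_gate) * XW_up)            # combined and mapped back to (b, s, d)

# LM-Head: hidden states -> vocabulary logits
W_lm_head = nn.Linear(d, V, bias=False)
logits = W_lm_head(H)                          # (b, s, V), the largest intermediate

print(Q.shape, XW_up.shape, logits.shape)
```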

Challenges with Large Tensors

Plugging in the numbers from the table shows where the pressure comes from. For Llama3/3.1, I ≈ 3.5d and V ≈ 31d, so the MLP intermediates of shape (b, s, I) and, above all, the logits of shape (b, s, V) are far larger than the hidden states of shape (b, s, d), and all of them grow linearly with the sequence length s. At s = 131072, the logits of a single sequence alone occupy tens of gigabytes, before counting weights, gradients, optimizer states, or any per-layer activations.
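
A quick back-of-the-envelope calculation makes this concrete (assuming bf16 storage at 2 bytes per element and batch size 1; these are my illustrative assumptions, not measurements):

```python
# Per-sequence activation sizes for Llama3.1-8B shapes in bf16 (2 bytes/element).
b, s, d, I, V = 1, 131072, 4096, 14336, 128256
BYTES = 2

def gib(n_elems: int) -> float:
    return n_elems * BYTES / 2**30

print(f"hidden states    (b, s, d): {gib(b * s * d):6.1f} GiB")  # ~1.0 GiB
print(f"MLP intermediate (b, s, I): {gib(b * s * I):6.1f} GiB")  # ~3.5 GiB per projection
print(f"LM-head logits   (b, s, V): {gib(b * s * V):6.1f} GiB")  # ~31.3 GiB
```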

Mini-Sequence Processing to Optimize Memory Usage

To handle the computational challenges posed by large tensor sizes, especially with long sequences and large vocabularies, we can adopt the following strategy:

Partitioning Input Sequences

Thankfully, this problem has already been solved for the batch dimension: mini-batch training, deployed as gradient accumulation in most training libraries, splits a large batch into smaller pieces that are processed one at a time. We can apply the same idea at the sequence level and work out the necessary details.
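
As a sketch of that idea (my own simplified illustration, not the library's actual implementation), the snippet below splits the sequence into M mini-sequences and runs the memory-hungry LM-Head and loss one chunk at a time, with checkpointing so that only one (b, s/M, V) slice of logits is alive at any moment:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def lm_head_loss_minisequence(hidden, labels, lm_head, M=4):
    """Cross-entropy over the vocabulary, computed M mini-sequences at a time.

    hidden: (b, s, d) hidden states; labels: (b, s) target token ids.
    Checkpointing each chunk means its logits are freed after the forward
    pass and recomputed during backward, so the full (b, s, V) tensor is
    never materialized at once.
    """
    def chunk_loss(h_chunk, y_chunk):
        logits = lm_head(h_chunk)                              # (b, s/M, V)
        return F.cross_entropy(logits.flatten(0, 1), y_chunk.flatten(),
                               reduction="sum")

    total = hidden.new_zeros(())
    for h_chunk, y_chunk in zip(hidden.chunk(M, dim=1), labels.chunk(M, dim=1)):
        total = total + checkpoint(chunk_loss, h_chunk, y_chunk, use_reentrant=False)
    return total / labels.numel()                              # equals the unchunked mean loss

# Toy usage (small dims for illustration; real Llama uses d = 4096, V = 128256).
b, s, d, V = 2, 64, 32, 100
lm_head = nn.Linear(d, V, bias=False)
hidden = torch.randn(b, s, d, requires_grad=True)
labels = torch.randint(V, (b, s))
lm_head_loss_minisequence(hidden, labels, lm_head, M=4).backward()
```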

Congratulations! We've essentially arrived at the key idea of the Mini-Sequence Transformers [NeurIPS'24] paper. Additionally, I recommend checking out the awesome Mini-Sequence Transformers repo.

[Figure: Mini-Sequence Transformer (MST) overview]

Sequence Extension

Let's dive into the details of how optimized intermediate memory usage enables training on long sequences.

Theorem: Sequence Extension

Mini-Sequence Transformers (MST) reduces intermediate memory usage by a factor of M, the number of mini-sequences. Combined with gradient checkpointing, the total memory occupation becomes:

$$\text{Memory} = W_{\text{mem}} + S \times \left( \dfrac{I_{\text{mem}}}{M} + \sqrt{L} \times A_{\text{mem}} \right),$$

where:

  • $W_{\text{mem}}$ is the memory occupied by the model weights,
  • $S$ is the sequence length,
  • $I_{\text{mem}}$ is the intermediate (MLP and LM-head) memory per token,
  • $A_{\text{mem}}$ is the activation memory per token per layer,
  • $L$ is the number of layers, and
  • $M$ is the number of mini-sequences.

The $\sqrt{L}$ factor comes from gradient checkpointing, which stores only on the order of $\sqrt{L}$ layers' activations and recomputes the rest during the backward pass.

For a GPU with maximum memory $M_{\text{max}}$, the maximum sequence length is:

$$S_{\text{max}} = \dfrac{M_{\text{max}} - W_{\text{mem}}}{\dfrac{I_{\text{mem}}}{M} + \sqrt{L} \times A_{\text{mem}}}.$$

Going a step further, when we combine it with CPU offloading:

$$S_{\text{max}} = \dfrac{M_{\text{max}} - W_{\text{mem}}}{\dfrac{I_{\text{mem}}}{M} + A_{\text{mem}}},$$

which is significantly longer than the standard implementation, where the maximum sequence length is:

$$S_{\text{max}} = \dfrac{M_{\text{max}} - W_{\text{mem}}}{I_{\text{mem}} + L \times A_{\text{mem}}}.$$

This demonstrates that by reducing the intermediate memory with MST, we can train models with much longer sequences within the same memory constraints.
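
To get a feel for what these formulas imply, here is a small calculation with purely illustrative numbers; the memory terms below are invented for the sake of the example and are not measurements of any real model:

```python
# Illustrative (made-up) values: memory in GB, per-token terms for a hypothetical model.
M_max = 80.0      # GPU memory budget, e.g. an 80 GB accelerator
W_mem = 30.0      # memory taken by the weights
I_mem = 4e-4      # intermediate (MLP + LM-head) memory per token
A_mem = 1e-5      # activation memory per token per layer
L, M = 32, 8      # number of layers, number of mini-sequences

budget = M_max - W_mem
standard    = budget / (I_mem + L * A_mem)            # baseline implementation
mst_ckpt    = budget / (I_mem / M + L**0.5 * A_mem)   # MST + gradient checkpointing
mst_offload = budget / (I_mem / M + A_mem)            # MST + CPU offload

print(f"standard:          {standard:12,.0f} tokens")
print(f"MST + checkpoint:  {mst_ckpt:12,.0f} tokens")
print(f"MST + CPU offload: {mst_offload:12,.0f} tokens")
```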

Example Implementation

We have developed methods and libraries to facilitate this approach. By installing the mini-sequence wrapper (`pip install -U git+https://github.com/wdlctc/mini-s`), you can efficiently train large transformer models with long sequences.


```python
from minis.mini_sequence import minisequence
model = minisequence(model)
```
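
For context, here is a hedged end-to-end sketch of how the wrapper can sit in an ordinary Hugging Face training setup. The checkpoint name is just a placeholder, and I am assuming the wrapped model keeps the standard causal-LM forward/loss interface; check the repo's README for the authoritative usage.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minis.mini_sequence import minisequence

# Placeholder checkpoint; substitute the Llama variant you actually train.
name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

# Wrap the model so its MLP and LM-head computations run over mini-sequences.
model = minisequence(model)

# Training proceeds as usual; the wrapper is meant to be a drop-in replacement,
# so a standard forward pass with labels should still produce a loss.
inputs = tokenizer("A very long training document ...", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()
```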

Key Takeaway

Partitioning the input into M mini-sequences changes neither the model weights nor the final outputs; it simply avoids materializing the largest intermediates, the (b, s, I) MLP activations and the (b, s, V) logits, all at once. That roughly M-fold reduction in intermediate memory is what unlocks much longer training sequences on the same hardware.

Conclusion

Understanding the tensor shapes and sizes in transformer models is crucial for identifying computational challenges and optimizing training. By focusing on how the dimensions of these tensors impact memory usage, especially with large sequence lengths and vocabulary sizes, we can develop strategies to make training more efficient.

Partitioning input sequences and processing them as mini-sequences is an effective way to reduce memory usage without sacrificing performance. This approach enables the training of large transformer models on long sequences using available hardware, paving the way for advancements in natural language understanding and generation tasks.


Note: For those interested in implementing these techniques, tools like the Mini-Sequence Transformers library are available and can be integrated into your projects to facilitate efficient training on long sequences.