If you want to try your hand at fine-tuning an LLM (Large Language Model): one of the first things you’re going to need to know is “will it fit on my GPU”.

The pre-eminent guide to estimating (VRAM) memory requirements is Transformer Math 101. It bears mentioning, though, that its heuristics are written in the context of frameworks such as GPT-NeoX and Megatron-DeepSpeed. These are specialist frameworks for training at scale, and are not what a beginner would reach for.

Are Transformer Math’s memory estimates applicable to idiomatic PyTorch, without additional frameworks? We put it to the test.

In brief, we found:

  • Memory costs measured on small models extrapolate well to larger models
  • PyTorch AMP's memory costs are similar to those of the frameworks discussed in Transformer Math, with one caveat (see next)
  • DeepSpeed ZeRO’s fused optimiser saves 2 bytes/param on mixed-precision over PyTorch AMP
  • Gradient accumulation can be ~free, but costs more if used in concert with AMP or DDP
  • AMP usually costs ~2 bytes/param. Cost increases when gradient accumulation is enabled, or becomes ~free if used in concert with DDP
  • DDP usually costs ~4 bytes/param, but becomes cheaper if used in concert with AMP
  • DDP can be made 2.5 bytes/param cheaper by enabling gradient_as_bucket_view
  • Only the smallest LLMs can be trained under DDP with realistic settings, even on 80GiB GPUs.
  • Training large LLMs requires more advanced distributed training techniques, to shard memory costs over multiple GPUs
  • Parameter-efficient finetuning is a more realistic option for the GPU-poor

Read on to see our reasoning and evidence.

Experimental setup

We built a simple training loop (mem_llm.py). It loads a model via Hugging Face (HF) transformers, and trains said model to complete random sequences. By running this training loop for a few steps, then taking a measurement: we can determine our peak memory usage.
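A minimal sketch of such a measurement loop follows. It is not mem_llm.py: `TinyLM` is a hypothetical stand-in for an HF causal language model (an embedding feeding a linear head), and `train_and_measure` is an illustrative helper of ours. It trains on random token sequences with plain SGD (no momentum, so no optimiser state) and, when the model is on a GPU, reads the peak of PyTorch's caching allocator via `torch.cuda.max_memory_reserved`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLM(nn.Module):
    """Hypothetical stand-in for an HF causal LM: embedding -> linear head."""

    def __init__(self, vocab_size: int = 128, hidden: int = 64):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))  # logits: (batch, seq, vocab)


def train_and_measure(model: TinyLM, steps: int = 2, batch_size: int = 1, seq_len: int = 8):
    """Train on random sequences; return (losses, peak reserved bytes, or None on CPU)."""
    device = next(model.parameters()).device
    # Stateless optimiser: SGD without momentum keeps no per-parameter state.
    opt = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.0)
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)
    losses = []
    for _ in range(steps):
        # Random tokens; the target at each position is simply the next token.
        seq = torch.randint(0, model.vocab_size, (batch_size, seq_len + 1), device=device)
        inputs, targets = seq[:, :-1], seq[:, 1:]
        logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, model.vocab_size), targets.reshape(-1))
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)  # free gradients between steps
        losses.append(loss.item())
    peak = torch.cuda.max_memory_reserved(device) if device.type == "cuda" else None
    return losses, peak


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    losses, peak = train_and_measure(TinyLM().to(device))
    if peak is not None:
        print(f"peak reserved: {peak / 2**20:.1f} MiB")
```

Dividing the peak by the parameter count gives the bytes/param figures used throughout this article; for a model this tiny, fixed allocator overheads would dominate, so the number only becomes meaningful at LLM scale.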

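We will evaluate the cost of three common techniques provided by an ML accelerator (explanations to follow later):

  • Mixed-precision training (PyTorch AMP)
  • Gradient accumulation
  • Distributed training (PyTorch DDP)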

In our trainer: this functionality is built using PyTorch idioms. Our implementation is similar to how HF accelerate operates in its most basic configurations (i.e. when it is not configured to delegate to other accelerators such as DeepSpeed or Megatron).

What we expect to see

Before we dive into measurements: let’s estimate how much memory we’ll need in theory. We will use reasoning similar to Transformer Math.

For non-distributed training,
Total Training Memory = Model + Optimiser + Activations + Gradients

Our experimental setup eliminates optimiser costs, through means of a stateless optimiser (SGD without momentum).

We also minimise activations, by using tiny batch sizes and sequence lengths. At batch-of-1, sequence length 8: we expect activations to be negligible (<100MiB) even on our largest model (Llama 7b), according to Transformer Math's worst-case activations formula, "no recomputation" (i.e. gradient checkpointing disabled).
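As a sanity check on that claim, here is Transformer Math's "no recomputation" formula, sbhL(10 + 24/t + 5as/(ht)) bytes, evaluated in Python. The llama-2-7b shape used below (hidden size 4096, 32 layers, 32 attention heads) is our assumption, taken from its published config; t is the tensor-parallel degree, which is 1 in our setup.

```python
def activations_no_recompute(s: int, b: int, h: int, L: int, a: int, t: int = 1) -> float:
    """Transformer Math 101 "no recomputation" activation memory, in bytes.

    s: sequence length, b: batch size, h: hidden size,
    L: number of layers, a: attention heads, t: tensor-parallel degree.
    """
    return s * b * h * L * (10 + 24 / t + 5 * a * s / (h * t))


# Assumed llama-2-7b shape: hidden 4096, 32 layers, 32 heads.
act_bytes = activations_no_recompute(s=8, b=1, h=4096, L=32, a=32)
print(f"{act_bytes / 2**20:.1f} MiB")  # 34.3 MiB, comfortably under 100MiB
```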

These measures narrow down the possible sources of overhead, if detected.

Estimated baseline costs in single-precision

First up, let’s establish the terminology we’re going to use:

Precision   Datatype(s)                Bytes per element
Single      float32 (may imply tf32)   4
Half        float16, bfloat16          2

The memory costs for pure-float32 training are relatively straightforward:

Model params        4 bytes/param
Model gradients     4 bytes/param
Optimiser states    None in our setup*
Activations         Negligible in our setup**
Total               8 bytes/param

* It's worth noting that our use of a stateless optimiser is highly unrealistic; typical optimiser costs are substantial. Enlisting an uncontroversial optimiser such as 32-bit AdamW would cost an additional 8 bytes/param.
** Likewise, activations would normally be a substantial cost, easily capable of growing bigger than all other costs combined. Even more so in pure-float32 training, where our activations will be 32-bit. By comparison: mixed-precision training enables activations to be smaller (employing the smaller datatype used for reduced-precision compute).
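The 8 bytes/param for 32-bit AdamW comes from its two float32 moment estimates per parameter. A minimal sketch that checks this against PyTorch's own `torch.optim.AdamW`, whose per-parameter state holds `exp_avg` and `exp_avg_sq` tensors (we ignore the scalar step counter); the helper name is ours, and a single `nn.Linear` stands in for a model.

```python
import torch
import torch.nn as nn


def adamw_state_bytes_per_param(layer: nn.Linear) -> float:
    """Bytes of AdamW moment state per parameter, after one optimiser step."""
    opt = torch.optim.AdamW(layer.parameters(), lr=1e-3)
    layer(torch.randn(2, layer.in_features)).sum().backward()
    opt.step()  # AdamW creates its state lazily, on the first step
    n_params = sum(p.numel() for p in layer.parameters())
    state_bytes = sum(
        t.numel() * t.element_size()
        for state in opt.state.values()
        for name, t in state.items()
        if name in ("exp_avg", "exp_avg_sq")  # first and second moments
    )
    return state_bytes / n_params


print(adamw_state_bytes_per_param(nn.Linear(64, 64)))  # 8.0
```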

We don’t dwell on this baseline, because it is a bit of a strawman. Floating-point operations are slow in 32-bit. This can be alleviated somewhat via the TensorFloat-32 optimization, which enables some operations to utilize tensor cores. But ultimately, wider formats incur more I/O costs, and lack accelerated kernels for essential operations such as FlashAttention.

For performance reasons, and to reduce the cost of activations: it is standard to train in mixed-precision.

Estimated baseline costs in mixed-precision

Mixed-precision costs are framework-specific. We will interpret the costs for the frameworks considered in Transformer Math, and also make an educated guess at the costs of PyTorch’s built-in mixed-precision implementation, AMP.

Transformer Math was validated against GPT-NeoX (circa April 2023), and was cross-checked against Megatron-DeepSpeed (against which its predictions held, for everything except activation checkpointing).
Its estimates are expected to be predictive to within ±10%, including on other accelerators such as Megatron-LM, PyTorch FSDP, and ColossalAI.
Thanks go to Quentin Anthony at EleutherAI for confirming the above.

GPT-NeoX implements mixed-precision via DeepSpeed ZeRO. In both DeepSpeed ZeRO and Megatron-DS: the model is held in half-precision, whilst the optimiser takes a float32 copy of the parameters.

PyTorch AMP (Automatic Mixed Precision) implements mixed-precision more or less the opposite way around. The model is held in float32, and AMP manages a half-precision copy.
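The PyTorch AMP arrangement can be seen in a few lines. This sketch uses `torch.autocast` on the CPU with bfloat16 so that it runs without a GPU (on a GPU you would pass `device_type="cuda"`, and float16 would additionally call for a gradient scaler); the helper name `amp_step` is ours. The parameters stay float32, the Linear's output is computed in half precision, and the gradients land in float32 `.grad` tensors.

```python
import torch
import torch.nn as nn


def amp_step(model: nn.Module, x: torch.Tensor, device_type: str = "cpu") -> torch.dtype:
    """One forward/backward pass under autocast; returns the forward output dtype."""
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        out = model(x)  # autocast runs the matmul on a bfloat16 copy of the weights
    loss = out.float().pow(2).mean()  # reduce in float32, outside autocast
    loss.backward()  # gradients flow back into the float32 parameters' .grad
    return out.dtype


model = nn.Linear(16, 16)  # parameters are held in float32
print(amp_step(model, torch.randn(4, 16)))  # torch.bfloat16
print(model.weight.dtype, model.weight.grad.dtype)  # torch.float32 torch.float32
```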

Our understanding is as follows:

                                Megatron-DS, DS ZeRO               PyTorch AMP
Model           Params          2 bytes/param                      4 bytes/param
                Gradients       2 bytes/param                      4 bytes/param
Compute copy    Params          None                               2 bytes/param*
                Gradients       None                               0 bytes/param**
Optimiser copy  Params          4 bytes/param                      None
                Gradients       4 bytes/param*** (if unfused)      None
Optimiser states                None in our setup                  None in our setup
Activations                     Negligible in our setup            Negligible in our setup
Total                           8 bytes/param (fused), or          10 bytes/param
                                12 bytes/param (unfused)

*We believe that a compute copy is not taken for parameters whose layers autocast to single-precision (for example Embedding and LayerNorm – investigated in this minimal reproducer).
LayerNorm params are insignificant (comprising 0.00% of llama-7b’s params), but Embedding can be one of the biggest layers in an LLM. Embedding becomes a smaller proportion of parameters as a model is scaled up, so we can disregard it as far as the overall pattern is concerned.
Consider this an upper bound, because the lifetime of a compute copy is a performance choice: it’s needed during both the forward and backward passes, so the fastest approach is to set aside enough memory to maintain permanent compute copies for all layers. But at minimum: one layer at a time can be recreated from the model’s float32 parameters, then freed after computation (and this work can be re-done during the backwards pass).
**In PyTorch AMP: we believe the half-precision gradients are never materialized. They are instead reduced into the model’s float32 gradients immediately. Looking at memory snapshots: we see that the backwards pass for a Linear layer allocates a float32-sized buffer but no corresponding half-precision-sized buffer.
***Transformer Math does not mention a "4 bytes/param master gradients" cost. We think this assumes a fused optimiser, such as DeepSpeed provides. We believe we see evidence of a master gradient cost in Megatron-DS, and in contemporary DeepSpeed’s unfused optimiser.
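The table reduces to simple arithmetic. A sketch, with scheme names of our own choosing, that totals the per-component costs and scales them to a given parameter count. Treat the results as upper-bound estimates, since (per footnote *) compute copies may not all be held at once.

```python
GIB = 2**30

# Per-component costs in bytes/param, as tabulated above
# (stateless optimiser, negligible activations, no GA or DDP).
COSTS = {
    "float32": {"model_params": 4, "model_grads": 4},
    "ds_zero_fused": {"model_params": 2, "model_grads": 2, "optim_params": 4},
    "ds_zero_unfused": {"model_params": 2, "model_grads": 2, "optim_params": 4, "optim_grads": 4},
    "pytorch_amp": {"model_params": 4, "model_grads": 4, "compute_params": 2, "compute_grads": 0},
}


def bytes_per_param(scheme: str) -> int:
    """Total bytes/param for one of the schemes in COSTS."""
    return sum(COSTS[scheme].values())


def estimate_gib(n_params: float, scheme: str) -> float:
    """Estimated training memory in GiB for a model with n_params parameters."""
    return n_params * bytes_per_param(scheme) / GIB


for scheme in COSTS:
    print(f"{scheme}: {bytes_per_param(scheme)} bytes/param, "
          f"{estimate_gib(6.7e9, scheme):.1f} GiB for a 6.7b model")
```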

The above mixed-precision costs assume that other features such as gradient accumulation and distributed training are disabled. We will now consider the costs of those features in isolation.

Estimated cost of gradient accumulation

Gradient accumulation is a technique which computes the gradients for a minibatch, one microbatch at a time, summing each result (divided by the number of microbatches). This is typically done to save memory (at the expense of wall time), for a given minibatch size.
In distributed training, gradient accumulation can also be exploited to reduce the frequency with which synchronization of gradients is required, reducing the proportion of training time spent on communication overhead.
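A minimal sketch of gradient accumulation in plain PyTorch, assuming equally-sized microbatches and a mean-reduced loss; the helper `accumulate_step` and the toy linear model are ours. The check at the end confirms that three microsteps produce the same update as one full-minibatch step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def accumulate_step(model: nn.Module, opt: torch.optim.Optimizer, microbatches) -> None:
    """One optimiser step over a minibatch supplied as (inputs, targets) microbatches."""
    n = len(microbatches)
    for inputs, targets in microbatches:
        loss = F.mse_loss(model(inputs), targets)
        # .backward() adds into existing .grad buffers; dividing by n makes the
        # accumulated gradient equal the mean over the whole minibatch.
        (loss / n).backward()
    opt.step()
    opt.zero_grad(set_to_none=True)


torch.manual_seed(0)
x, y = torch.randn(6, 4), torch.randn(6, 1)
accumulated, full = nn.Linear(4, 1), nn.Linear(4, 1)
full.load_state_dict(accumulated.state_dict())  # identical starting weights
accumulate_step(accumulated, torch.optim.SGD(accumulated.parameters(), lr=0.1),
                list(zip(x.chunk(3), y.chunk(3))))  # 3 microsteps of 2 samples each
accumulate_step(full, torch.optim.SGD(full.parameters(), lr=0.1), [(x, y)])
print(all(torch.allclose(p, q, atol=1e-6)
          for p, q in zip(accumulated.parameters(), full.parameters())))  # True
```

Under DDP, running all but the last microbatch inside the model's `no_sync()` context manager skips the gradient allreduce for those microsteps, which is how accumulation reduces communication.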

When running in mixed-precision: we expect gradient accumulation to cost:

Megatron-DS, DS ZeRO        PyTorch AMP
0 or 4 bytes/param*         0 bytes/param**

*DeepSpeed offers two gradient accumulation modes, based on whether you accumulate and reduce gradients using the same datatype. If the datatypes match: you can use the same buffer for both, saving memory. This is an unrealistic/legacy option though; it is better to reduce in half-precision (to send less data over the network in distributed training) and accumulate in float32 (accumulators benefit from increased precision). This requires the allocation of an additional float32 buffer, costing 4 bytes/param.
**In PyTorch AMP: we already paid for a buffer that can be used for gradient accumulation. Our backwards pass and optimisation step are not fused together; we store float32 gradients, then we step our optimiser. The backward pass accumulates the half-precision compute gradients into this float32 gradient, whether we do 1 microstep or many.

Estimated cost of distributed training

There are many ways to distribute training. Among the simplest is DDP (Distributed Data Parallel), which replicates the model across multiple GPUs. We wish to measure its costs here, because its ease-of-use makes it a good fit for hobbyist or small-scale training.

DDP doesn’t reduce the size of the model or optimiser states. It enables us to achieve bigger minibatch sizes, or give smaller microbatches to each GPU, reducing the per-GPU cost of activations. The experiment we’re doing here has already minimized the cost of activations. Hence we don’t expect DDP to save us memory, but we are interested in measuring the overhead incurred by the extra communication it necessitates.

We expect PyTorch DDP to cost 4 bytes/param, as its distributed Reducer creates gradient buckets for each parameter. This cost can be reduced by enabling gradient_as_bucket_view.
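Enabling the option is a single constructor argument. A sketch, assuming a single CPU process with the gloo backend so that it runs without GPUs; a real job would instead be launched with torchrun across GPUs, typically on the nccl backend. `ddp_step` is a hypothetical helper of ours.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_step(bucket_view: bool = True) -> float:
    """One DDP training step in a single CPU process (world size 1, gloo backend)."""
    store = os.path.join(tempfile.mkdtemp(), "store")  # rendezvous file for init
    dist.init_process_group("gloo", init_method=f"file://{store}", rank=0, world_size=1)
    try:
        # With gradient_as_bucket_view=True, each param.grad is a view into DDP's
        # allreduce bucket rather than a separate copy of the gradient.
        model = DDP(nn.Linear(8, 8), gradient_as_bucket_view=bucket_view)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()  # the Reducer allreduces gradients bucket by bucket
        opt.step()
        return loss.item()
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(ddp_step(bucket_view=True))
```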

The frameworks studied in Transformer Math can be distributed in more exotic ways than DDP, capable of sharding the model and optimiser states. DeepSpeed ZeRO allocates a ~477MiB bucket for allreduce operations, a fixed cost that should not scale with model size.

What we observed experimentally

We trained with batch-of-1, short sequences, and a stateless optimiser for 2 steps.
Where gradient accumulation was used: we ran each step for 3 microsteps.

For 7b models, we used naïve model parallelism to vertically-shard costs over two 48GiB GPUs. This type of parallelism is non-performant and therefore unrealistic, but serves as a good way to simulate having one larger GPU.
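A sketch of the naive (vertical) split: each GPU holds half of the layers, and activations are handed from one device to the next. `TwoStageModel` is a hypothetical illustration, not the code we ran; it falls back to the CPU when two GPUs are absent.

```python
import torch
import torch.nn as nn


class TwoStageModel(nn.Module):
    """Naive model parallelism: first half of the layers on dev0, the rest on dev1."""

    def __init__(self, layers, dev0: str, dev1: str):
        super().__init__()
        split = len(layers) // 2
        self.dev0, self.dev1 = torch.device(dev0), torch.device(dev1)
        self.stage0 = nn.Sequential(*layers[:split]).to(self.dev0)
        self.stage1 = nn.Sequential(*layers[split:]).to(self.dev1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stage0(x.to(self.dev0))
        # The stages run one after another, so each device idles while the other works.
        return self.stage1(x.to(self.dev1))


two_gpus = torch.cuda.device_count() >= 2
dev0, dev1 = ("cuda:0", "cuda:1") if two_gpus else ("cpu", "cpu")
model = TwoStageModel([nn.Linear(32, 32) for _ in range(4)], dev0, dev1)
print(model(torch.randn(2, 32)).shape)  # torch.Size([2, 32])
```

The idling is why this approach is non-performant; what it does buy us is per-GPU memory, since each device stores only its half of the parameters and gradients.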

Note also that when Nvidia markets the capacity of a GPU, they use the measure ‘GB’ but they actually mean GiB.

We were not able to measure DDP costs for 7b models, as we had no remaining GPUs over which to replicate the training.

Memory cost (GiB)

Total memory used + reserved (GiB):

                                   DDP off                         DDP on
                                   AMP off         AMP on          AMP off         AMP on
Arch      Model         Params (b) GA off  GA on   GA off  GA on   GA off  GA on   GA off  GA on
GPT-NeoX  pythia-1.4b   1.4        10.7    11.1    12.7    13.8    16.1    16.6    16.3    19.5
          pythia-2.8b   2.8        21.0    21.5    25.0    27.4    31.8    32.1    32.4    37.6
          pythia-6.9b   6.9        51.4    53.1    61.7    67.8    -       -       -       -
Llama     openllama-3b  3.4        25.9    26.3    29.8    34.6    39.5    39.5    40.2    >48
          llama-2-7b    6.7        50.5    51.6    57.4    67.0    -       -       -       -

We can see that almost none of these training configurations fit on an Nvidia RTX 4090 (24GiB), the largest consumer Nvidia GPU at the time of writing.
It’s even a tight fit on a lot of server Nvidia GPUs (e.g. 40GiB and 48GiB). Bear in mind that these training configurations constitute a lower bound (we aren’t paying for optimiser state, and our activations are unrealistically small).

We see that enabling all of these training techniques is punitive, nearly doubling our memory costs.

It’s clear why Nvidia’s top-end 80GiB GPUs are desirable, and why there is interest in training on AMD’s 128GiB GPUs. We also see that realistically, the training of large models must rely upon the sharding of memory costs across GPUs.

Let’s now try to understand these costs as a function of model size.

Memory cost (bytes/param)

Here we divide the memory usage by the number of model parameters.

Total memory used + reserved (bytes/param):

                                   DDP off                         DDP on
                                   AMP off         AMP on          AMP off         AMP on
Arch      Model         Params (b) GA off  GA on   GA off  GA on   GA off  GA on   GA off  GA on
GPT-NeoX  pythia-1.4b   1.4        8.1     8.4     9.6     10.5    12.2    12.6    12.3    14.8
          pythia-2.8b   2.8        8.1     8.3     9.7     10.6    12.3    12.4    12.5    14.5
          pythia-6.9b   6.9        8.0     8.3     9.7     10.6    -       -       -       -
Llama     openllama-3b  3.4        8.1     8.2     9.3     10.9    12.4    12.4    12.6    >15.0
          llama-2-7b    6.7        8.0     8.2     9.1     10.7    -       -       -       -
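The conversion behind these figures is a single division, remembering that the memory measurements are in binary GiB (2^30 bytes). A sketch using exact parameter counts that we assume from the models' published details; the rounded counts in the Params column would shift the results slightly.

```python
GIB = 2**30


def gib_to_bytes_per_param(gib: float, n_params: int) -> float:
    """Convert a measured memory footprint (GiB) into bytes per model parameter."""
    return gib * GIB / n_params


# Assumed exact parameter counts (from the models' published details).
PYTHIA_1_4B = 1_414_647_808
LLAMA_2_7B = 6_738_415_616

print(f"{gib_to_bytes_per_param(10.7, PYTHIA_1_4B):.1f}")  # 8.1
print(f"{gib_to_bytes_per_param(50.5, LLAMA_2_7B):.1f}")  # 8.0
```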

Remarkably, some very regular patterns emerge: memory scales more or less linearly with model size.

Our (strawman) baseline of “everything in float32” is indeed 8 bytes/param. This baseline should be low-overhead; it will incur less fragmentation, as it does not need to allocate as many temporary buffers as the other configurations.

Our estimate that “mixed-precision should cost 10 bytes/param” looks close (it’s actually slightly cheaper), but the theory “DDP costs 4 bytes/param more” only holds true when gradient accumulation is enabled. Something to dig into.

Let’s visualise these datapoints as relative overheads compared to our float32 baseline.

Memory cost (bytes/param), relative to float32 baseline

Relative overhead (bytes/param):

                                   DDP off                         DDP on
                                   AMP off         AMP on          AMP off         AMP on
Arch      Model         Params (b) GA off  GA on   GA off  GA on   GA off  GA on   GA off  GA on
GPT-NeoX  pythia-1.4b   1.4        0.00    0.29    1.52    2.34    4.08    4.45    4.22    6.65
          pythia-2.8b   2.8        0.00    0.18    1.56    2.45    4.19    4.27    4.39    6.41
          pythia-6.9b   6.9        0.00    0.28    1.62    2.57    -       -       -       -
Llama     openllama-3b  3.4        0.00    0.12    1.20    2.73    4.25    4.26    4.47    >6.91
          llama-2-7b    6.7        0.00    0.18    1.10    2.63    -       -       -       -

Some results here match our expectations, but there are some surprises also.

Gradient accumulation on its own is almost free (as expected), but it has interesting interactions depending on which other features are enabled. Using it in concert with mixed-precision costs more than the sum of either feature’s cost in isolation.

Mixed-precision on its own is cheaper than the 2 bytes/param that we predicted. Perhaps PyTorch AMP found situations where it could reclaim freed temporary buffers. This could be possible if (for example) it opted to recreate half-precision compute weights rather than keep them around.

We do expect mixed-precision to cost slightly below 2 bytes/param, because embeddings and norms do not require half-precision copies (they are computed in float32). However, the discrepancy we’re seeing seems too large to be closed by that theory. Discounting a half-precision copy of pythia 6.9b’s embedding layer would only explain 0.06 bytes/param of this undershoot (and another 0.00 bytes/param if we consider norm params).
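The 0.06 figure is quick to reproduce. A sketch, assuming pythia-6.9b's vocabulary of 50,432 tokens, hidden size of 4,096 and total of 6,857,302,016 parameters (our assumptions, from its published config and paper):

```python
# Assumed pythia-6.9b shape (from its published config and paper).
VOCAB_SIZE = 50_432
HIDDEN_SIZE = 4_096
TOTAL_PARAMS = 6_857_302_016

embedding_params = VOCAB_SIZE * HIDDEN_SIZE  # input embedding matrix
half_copy_bytes = 2 * embedding_params  # a float16/bfloat16 compute copy
print(f"{half_copy_bytes / TOTAL_PARAMS:.2f} bytes/param")  # 0.06 bytes/param
```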

DDP in isolation costs about 4 bytes/param, as expected. This can be reduced by enabling gradient_as_bucket_view. We conducted a limited exploration of its benefits:

DDP, mixed precision and gradient accumulation all enabled; Off/On refers to gradient_as_bucket_view.

                        Mem (GiB)       bytes/param     Memory saved
Arch      Model         Off     On      Off     On      GiB     bytes/param
GPT-NeoX  pythia-1.4b   19.5    16.1    14.77   12.20   3.38    2.57
          pythia-2.8b   37.8    31.4    14.61   12.16   6.33    2.45

The pattern appears to scale with model size; enabling gradient_as_bucket_view seems to save about 2.5 bytes/param on the cost of DDP. This is surprising. From the description ("This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size"): it sounded like we could expect the overhead to change from a per-parameter cost to a fixed overhead (perhaps structured as a single allreduce buffer, like in GPT-NeoX). That would mean the saving converges on 4 bytes/param as model size increases. Such a trend is not evidenced by our measurements.

DDP, AMP and GA exhibit interesting interactions with each other. Let’s determine “how much would it cost to enable X, given Y and/or Z are enabled”.

Memory cost (bytes/param) of enabling mixed-precision

Given a starting point of “DDP and GA are enabled/disabled”: how much does it cost to enable mixed-precision?

AMP cost (bytes/param):

                        DDP off         DDP on
Arch      Model         GA off  GA on   GA off  GA on
GPT-NeoX  pythia-1.4b   1.52    2.05    0.14    2.20
          pythia-2.8b   1.56    2.27    0.21    2.13
          pythia-6.9b   1.62    2.29    -       -
Llama     openllama-3b  1.20    2.61    0.22    OOM
          llama-2-7b    1.10    2.44    -       -

If you’re already paying for DDP: mixed-precision can be added essentially for free. This indicates that some memory re-use becomes possible. Perhaps because mixed-precision and DDP require memory at different times: half-precision compute copies can reside in the DDP allreduce buffer, which is needed only after the mixed-precision results have already been computed.

When gradient accumulation is off: mixed-precision costs less than the 2 bytes/param upper bound we predicted.

This could be because a layer’s float32 gradients and its half-precision compute copy don't need to exist at the same time. At least, if we meet two conditions:

  • Conclude training steps by deleting the model’s gradients (Optimizer.zero_grad(set_to_none=True))
  • Disable gradient accumulation

These conditions are satisfied by our experimental setup. Consequently, there is an opportunity for memory re-use.
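The first condition is easy to observe: with set_to_none=True, zero_grad releases the gradient tensors rather than filling them with zeros, so between steps no gradient memory is held. A toy `nn.Linear` stands in for the model in this sketch.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
layer(torch.randn(2, 4)).sum().backward()
print(layer.weight.grad is None)  # False: backward allocated float32 gradients
opt.step()
opt.zero_grad(set_to_none=True)  # drop the gradient tensors entirely
print(layer.weight.grad is None)  # True: their memory returns to the allocator
```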

Perhaps when the backward pass frees two layers’ half-precision compute copies: the memory is used to pay for one layer’s float32 gradients. This would halve the 4 bytes/param cost we reserved for float32 gradients (as only half of our layers would need to allocate new memory). As compelling as this theory is: we would expect a 0 bytes/param overhead for enabling mixed-precision if this were the case (as mixed-precision would subsidise a cost we were already paying for). Moreover, it assumes that two pages could be merged into a bigger contiguous page; we are not convinced this happens.

An easier explanation could be that the lifetime of a compute copy is limited, and some layers recreate their compute copies from the float32 model parameters.

When gradient accumulation is on: enabling mixed-precision costs more than the 2 bytes/param that we thought it could cost in theory. This could indicate that less memory re-use is possible, or that more fragmentation occurs.

Memory cost (bytes/param) of enabling Gradient accumulation

Given a starting point of “DDP and AMP are enabled/disabled”: how much does it cost to enable gradient accumulation?

GA cost (bytes/param):

                        DDP off           DDP on
Arch      Model         AMP off  AMP on   AMP off  AMP on
GPT-NeoX  pythia-1.4b   0.29     0.82     0.37     2.43
          pythia-2.8b   0.18     0.89     0.09     2.01
          pythia-6.9b   0.28     0.95     -        -
Llama     openllama-3b  0.12     1.53     0.01     OOM
          llama-2-7b    0.18     1.53     -        -

Both DDP and AMP have adverse interactions with gradient accumulation. Enabling either will increase the cost, but the cost of enabling both is greater than the sum of their individual overheads.

This is an overhead we encountered previously when using HF accelerate libraries, but now we see that it reproduces in pure PyTorch.

Memory cost (bytes/param) of enabling DDP

Given a starting point of “AMP and GA are enabled/disabled”: how much does it cost to enable DDP?

DDP cost (bytes/param):

                        AMP off         AMP on
Arch      Model         GA off  GA on   GA off  GA on
GPT-NeoX  pythia-1.4b   4.08    4.16    2.70    4.31
          pythia-2.8b   4.19    4.09    2.83    3.95
          pythia-6.9b   -       -       -       -
Llama     openllama-3b  4.25    4.14    3.27    >4.18
          llama-2-7b    -       -       -       -

In most cases, it costs 4 bytes/param to enable DDP. There is one synergetic case, whereby if mixed-precision is used without gradient accumulation, then the cost of DDP decreases.

Conclusions

The memory costs we derived for PyTorch AMP aren’t a million miles away from what’s described in Transformer Math. Moreover, these heuristics predicted our empirical results reasonably well.

The training techniques studied here incurred greater overheads when used in combination than when used individually, perhaps due to memory fragmentation. Some combinations appeared to exhibit synergies, perhaps due to re-use of freed temporary memory.

For DDP specifically, we learned that gradient_as_bucket_view can be enabled, to save a huge amount of memory (~2.5 bytes/param).

With all techniques enabled (mixed-precision training, gradient accumulation, DDP): our lower bound is somewhere in the region of 15 bytes/param.

Add a realistic optimiser (32-bit AdamW*) and that increases to 23 bytes/param, or 145GiB for llama 7b. This exceeds the capacity of most GPUs on the market. It could fit on an AMD MI300X 192GB!
*More exotic optimisers exist, with lower memory requirements, such as 8-bit AdamW. 32-bit AdamW is a good place to start if you have enough memory.
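The arithmetic behind that estimate, as a sketch: 15 bytes/param is our approximate measured lower bound, and the exact llama-2-7b parameter count is our assumption. With exactly 23 bytes/param the result lands just under the ~145GiB quoted above; the precise figure depends on how far above 15 bytes/param the lower bound sits.

```python
GIB = 2**30

LOWER_BOUND = 15  # bytes/param: AMP + GA + DDP, measured above (approximate)
ADAMW_32BIT = 8  # bytes/param: two float32 moment estimates per parameter
LLAMA_2_7B = 6_738_415_616  # assumed exact parameter count

total_bytes_per_param = LOWER_BOUND + ADAMW_32BIT
print(total_bytes_per_param)  # 23
print(f"{total_bytes_per_param * LLAMA_2_7B / GIB:.0f} GiB")
```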

It’s clear that beginner-level techniques won’t cut it for fine-tuning LLMs; we need to shard the model parameters, gradients, optimiser states and data over multiple GPUs, trading DDP for more advanced techniques such as PP (pipeline parallelism), FSDP (fully-sharded data parallel), ZeRO (Zero Redundancy Optimizer), and/or TP (tensor parallelism).

Realistic activations for a model such as llama-2 7b could be 1GiB per batch item** (0.16 bytes/param), and a realistic batch size could be 128 or higher. That’s another 128GiB we will only fit by distributing over many GPUs, or by serializing into microbatches via gradient accumulation.
**Transformer Math “full recomputation” formula (so this is a lower bound, using gradient checkpointing), for sequence length 4096.
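The footnoted figure checks out. A sketch of Transformer Math's "full recomputation" formula (2sbhL bytes), assuming llama-2-7b's hidden size of 4096 and 32 layers, plus a parameter count of 6,738,415,616 (also our assumption):

```python
def activations_full_recompute(s: int, b: int, h: int, L: int) -> int:
    """Transformer Math 101 "full recomputation" activation memory, in bytes.

    s: sequence length, b: batch size, h: hidden size, L: number of layers.
    """
    return 2 * s * b * h * L


# Assumed llama-2-7b shape: hidden size 4096, 32 layers.
per_item = activations_full_recompute(s=4096, b=1, h=4096, L=32)
print(per_item / 2**30)  # 1.0 (GiB per batch item)
print(round(per_item / 6_738_415_616, 2))  # 0.16 (bytes/param)
print(128 * per_item / 2**30)  # 128.0 (GiB at batch size 128)
```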

As daunting as this looks: we should spare a thought for those who confront pretraining (training a model from scratch), which entails larger batch sizes, longer training runs, and more failed attempts. It can cost $250k to pretrain a 7b model, and 7b is on the small side for LLMs!

But there is an easier way! For the rest of us, parameter-efficient finetuning methods exist, which drastically lower the compute requirements, making fine-tuning possible on consumer GPUs. We hope to explore these techniques in a future blog post.

Acknowledgements

Thanks to collaborators at EleutherAI — including the authors of Transformer Math 101 — for answering questions during the writing of this article. Thanks also to Tim Dettmers for explaining optimiser costs.

Citing

@article{birch2023llmmem,
  title   = "LLM finetuning memory requirements",
  author  = "Birch, Alex",
  journal = "blog.scottlogic.com",
  year    = "2023",
  url     = "https://blog.scottlogic.com/2023/11/24/llm-mem.html"
}

References

This list is limited to our primary reference (Transformer Math 101) and arXiv references; articles and papers which provided BibTeX records. For a more complete list of our references: please refer to the hyperlinks in this article.

[1] Anthony et al. Transformer Math 101 blog.eleuther.ai 2023
[2] Li et al. PyTorch Distributed: Experiences on Accelerating Data Parallel Training VLDB 2020
[3] Dao et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness arXiv preprint arXiv:2205.14135 (2022)
[4] Kingma et al. Adam: A Method for Stochastic Optimization ICLR 2015
[5] Loshchilov et al. Decoupled Weight Decay Regularization ICLR 2019
[6] Dettmers et al. 8-bit Optimizers via Block-wise Quantization ICLR 2022
[7] Huang et al. GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism NeurIPS 2019
[8] Rajbhandari et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models arXiv preprint arXiv:1910.02054 (2019)