
Quick Start

This page takes you from a fresh install to an optimized training run in about five minutes.

1. Install

```bash
pip install memscale
```

See Installation for PyTorch and CUDA requirements.
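To confirm the install, you can check which package versions are present. The sketch below uses only the standard library's `importlib.metadata`. It assumes the distribution names are `memscale` (from the pip command above) and `torch` (PyTorch's usual distribution name). The helper name `installed_version` is hypothetical. The sketch does not check CUDA; Installation covers that.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Distribution names are assumptions: "memscale" from the pip command,
# "torch" as PyTorch's usual distribution name.
for name in ("memscale", "torch"):
    found = installed_version(name)
    print(f"{name}: {found if found else 'not installed'}")
```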

2. Wrap your model

The integration is a single call. Import memscale, then pass your model and optimizer to apply_all_optimizations():

```python
import memscale
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
model.train()  # from_pretrained() returns the model in eval mode
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Full optimization stack: the configuration behind the benchmark numbers below
model, optimizer = memscale.apply_all_optimizations(model, optimizer)

# Train as usual. `dataloader` is your existing DataLoader; each batch must
# include `labels` so that the model returns a loss.
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

MemScale profiles the model, builds a per-layer optimization plan, and attaches hooks. Your training loop does not change.
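To make "per-layer optimization plan" concrete, here is a toy greedy planner. It is a hypothetical illustration, not MemScale's actual algorithm. It takes a per-layer activation profile, given here as made-up numbers, and marks the layers with the largest activations for checkpointing until the estimated total fits a budget. The names `LayerProfile` and `plan_checkpointing` are invented for this sketch. The planner also assumes that checkpointing a layer frees all of its saved activations, which overstates the real savings.

```python
from dataclasses import dataclass


@dataclass
class LayerProfile:
    """Hypothetical profile entry: layer name and estimated saved-activation bytes."""
    name: str
    activation_bytes: int


def plan_checkpointing(layers: list[LayerProfile], budget_bytes: int) -> tuple[dict[str, str], int]:
    """Toy greedy planner: checkpoint layers until activations fit the budget.

    Assumes checkpointing a layer frees all of its saved activations, which
    overstates real savings (recomputation still needs the layer's input).
    """
    plan = {layer.name: "keep" for layer in layers}
    remaining = sum(layer.activation_bytes for layer in layers)
    # Largest activations first: each checkpointed layer frees the most memory.
    for layer in sorted(layers, key=lambda l: l.activation_bytes, reverse=True):
        if remaining <= budget_bytes:
            break
        plan[layer.name] = "checkpoint"
        remaining -= layer.activation_bytes
    return plan, remaining


profile = [
    LayerProfile("embed", 100),
    LayerProfile("block0", 400),
    LayerProfile("block1", 400),
    LayerProfile("head", 200),
]
plan, remaining = plan_checkpointing(profile, budget_bytes=600)
# block0 and block1 are checkpointed; 300 bytes of activations remain.
print(plan, remaining)
```

A greedy largest-first order keeps the number of recomputed layers small, which limits the extra compute that checkpointing costs.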

3. wrap() vs apply_all_optimizations()

MemScale gives you two entry points:

  • memscale.wrap(model): the simplest call. It applies checkpointing and CPU offload only (balanced mode). It is lightweight, but the memory reduction is modest.
  • memscale.apply_all_optimizations(model, optimizer): the full stack. It adds BF16 mixed precision and 8-bit Adam on top of checkpointing and offload. This is what produces the headline reductions in the table below, and what the benchmark suite measures.

The benchmark numbers below come from apply_all_optimizations(). A plain wrap() call will not reach them, because it omits BF16 and 8-bit Adam, which account for most of the savings. Use apply_all_optimizations() when you want maximum savings.
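A back-of-envelope estimate shows why 8-bit Adam matters. Adam keeps two state tensors per parameter (the first and second moments), so its state memory scales as 2 × parameters × bytes per state. The sketch below assumes 4-byte FP32 states for standard AdamW and 1-byte states for 8-bit Adam. It ignores the small quantization constants that 8-bit implementations also store, and it reports decimal gigabytes (1e9 bytes). BF16 mixed precision saves memory separately and is not modeled here. The helper name `adam_state_gb` is hypothetical.

```python
def adam_state_gb(num_params: float, bytes_per_state: int) -> float:
    """Approximate Adam optimizer-state memory in decimal GB (1e9 bytes).

    Adam stores two tensors per parameter (first and second moments).
    Quantization metadata kept by 8-bit implementations is ignored.
    """
    return 2 * num_params * bytes_per_state / 1e9


gpt2_medium_params = 355e6  # 355M, as listed for GPT-2 Medium in the benchmark table
fp32_states = adam_state_gb(gpt2_medium_params, bytes_per_state=4)
int8_states = adam_state_gb(gpt2_medium_params, bytes_per_state=1)
print(f"FP32 Adam states:  {fp32_states:.2f} GB")
print(f"8-bit Adam states: {int8_states:.2f} GB")
print(f"Saved:             {fp32_states - int8_states:.2f} GB")
```

For GPT-2 Medium this gives about 2.84 GB of FP32 optimizer state versus 0.71 GB in 8-bit, roughly 2.1 GB saved from this component alone.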

Before / after VRAM

These are measured peak training VRAM figures on an NVIDIA RTX 3090 (24 GB), from the project’s benchmark suite, using the full apply_all_optimizations stack:

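| Model | Params | Batch × Seq | Baseline | MemScale | Reduction |
|---|---|---|---|---|---|
| BERT-Base | 110M | 16 × 128 | 3.14 GB | 0.84 GB | 73.1% |
| GPT-2 Medium | 355M | 4 × 512 | 10.87 GB | 2.61 GB | 76.0% |
| GPT-2 Large | 774M | 2 × 512 | 14.87 GB | 4.68 GB | 68.5% |
| GPT-2 XL | 1.5B | 1 × 512 | OOM | 9.25 GB | enables training |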

Without MemScale, GPT-2 XL runs out of memory on the 24 GB card. With the full stack, it trains at a 9.25 GB peak.
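The Reduction column matches the relative drop in peak VRAM, (baseline - MemScale) / baseline. The helper below sketches that calculation, assuming this is how the column was derived; the name `reduction_pct` is hypothetical. The GB figures in the table are rounded to two decimals, so recomputing from them can differ in the last digit. BERT-Base, for example, comes out at 73.2% rather than the listed 73.1%.

```python
def reduction_pct(baseline_gb: float, optimized_gb: float) -> float:
    """Percentage drop in peak VRAM relative to the baseline run."""
    if baseline_gb <= 0:
        raise ValueError("baseline_gb must be positive")
    return 100.0 * (baseline_gb - optimized_gb) / baseline_gb


# Rows from the benchmark table (rounded GB figures). GPT-2 XL is omitted:
# its baseline run went out of memory, so no percentage is defined.
rows = {
    "BERT-Base": (3.14, 0.84),
    "GPT-2 Medium": (10.87, 2.61),
    "GPT-2 Large": (14.87, 4.68),
}
for model_name, (baseline, optimized) in rows.items():
    print(f"{model_name}: {reduction_pct(baseline, optimized):.1f}%")
```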

Next steps