
Quick Start

This page takes you from a fresh install to an optimized training run in about five minutes.

1. Install

pip install memscale

See Installation for PyTorch and CUDA requirements.

2. Wrap your model

The whole integration is one line. Import memscale, then pass your model to wrap():

import memscale
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
model = memscale.wrap(model)   # default config: balanced mode

# train as usual
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# dataloader is your existing DataLoader; each batch is a dict of tensors
# that includes labels, so that the model output carries a loss
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

wrap() profiles the model, builds a per-layer optimization plan, and attaches hooks. Your training loop does not change.

3. What wrap() does

With no Config argument, wrap() uses balanced mode: gradient checkpointing and CPU offloading are enabled, activation tiling is off. See First Optimization to customize this.

Before / after VRAM

These are measured peak training VRAM figures on an NVIDIA RTX 3090 (24 GB), from the project’s benchmark suite:

Model          Params   Batch × Seq   Baseline    MemScale   Reduction
BERT-Base      110M     16 × 128      3.14 GB     0.84 GB    73.1%
GPT-2 Medium   355M     4 × 512       10.87 GB    2.61 GB    76.0%
GPT-2 Large    774M     2 × 512       14.87 GB    4.68 GB    68.5%
GPT-2 XL       1.5B     1 × 512       OOM         9.25 GB    enables training
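
The table does not define the Reduction column. The sketch below assumes it is the relative drop in peak VRAM, (Baseline - MemScale) / Baseline, and checks that reading against two rows. The helper name reduction_percent is hypothetical and not part of MemScale. Recomputing from the rounded figures shown in the table can differ in the last digit (the BERT-Base row gives 73.2% this way), presumably because the published percentages come from unrounded measurements.

```python
def reduction_percent(baseline_gb, optimized_gb):
    """Relative drop in peak VRAM, as a percentage of the baseline."""
    return (baseline_gb - optimized_gb) / baseline_gb * 100


# Figures copied from the GPT-2 Medium and GPT-2 Large rows of the table.
print(f"{reduction_percent(10.87, 2.61):.1f}%")   # 76.0%
print(f"{reduction_percent(14.87, 4.68):.1f}%")   # 68.5%
```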

Without MemScale, GPT-2 XL runs out of memory on the 24 GB card even at batch size 1; with MemScale, training peaks at 9.25 GB.
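
To check figures like these on your own hardware, PyTorch's CUDA memory statistics can report the peak allocation for a single training step. The following is a minimal sketch, not the benchmark suite's method: peak_vram_gb and step_fn are hypothetical names that are not part of MemScale. It assumes "peak VRAM" means memory allocated through PyTorch's caching allocator and that 1 GB is counted as 1024³ bytes, either of which may differ from how the published numbers were measured.

```python
import torch


def peak_vram_gb(step_fn, device=None):
    """Run step_fn once and return the peak CUDA memory allocated by PyTorch, in GB.

    Returns None when no CUDA device is available (step_fn still runs).
    """
    if not torch.cuda.is_available():
        step_fn()
        return None
    torch.cuda.reset_peak_memory_stats(device)  # clear the previous high-water mark
    step_fn()
    torch.cuda.synchronize(device)              # wait for queued kernels to finish
    return torch.cuda.max_memory_allocated(device) / 1024**3
```

Pass a function that runs one forward pass, backward pass, and optimizer step. Running it once on the plain model and once on the wrapped model gives a before/after pair for your own batch size and sequence length. Tools such as nvidia-smi typically report higher values because they also count the CUDA context and memory the allocator has cached but not handed out.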

Next steps