Hugging Face Trainer
MemScale integrates with the Hugging Face Trainer directly — you pass the
Trainer to wrap() and keep calling .train() as usual.
There is no apply_all_optimizations() function. Earlier drafts mentioned
one; it does not exist in MemScale. The integration is just wrap().
Wrapping a Trainer
import memscale
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=4,
    num_train_epochs=1,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# one line — wrap the Trainer before training
trainer = memscale.wrap(trainer)
trainer.train()  # VRAM optimized automatically
wrap() detects that it received a Trainer (it has both a .model and a
.train method), prepares the inner model, and attaches the optimization
hooks. The Trainer API is otherwise untouched.
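The detection rule described above amounts to duck typing. The following is a minimal sketch, assuming the check is nothing more than looking for a `.model` attribute and a callable `.train`. The name `looks_like_trainer` and the stand-in objects are hypothetical and not part of MemScale's API.

```python
from types import SimpleNamespace


def looks_like_trainer(obj) -> bool:
    """Hypothetical helper: True if obj has a .model attribute and a callable .train."""
    return hasattr(obj, "model") and callable(getattr(obj, "train", None))


class FakeTrainer:
    """Stand-in for transformers.Trainer, used only for illustration."""

    def __init__(self, model):
        self.model = model

    def train(self):
        return "trained"


trainer_like = FakeTrainer(model=object())

# A bare module also has a .train method (as torch.nn.Module does),
# so the .model attribute is what tells the two cases apart.
module_like = SimpleNamespace(train=lambda mode=True: None)

is_trainer = looks_like_trainer(trainer_like)
is_module_a_trainer = looks_like_trainer(module_like)
```

Checking both attributes matters because a plain PyTorch module exposes `.train()` as well; requiring `.model` too keeps a bare model from being mistaken for a Trainer.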
With a custom config
from memscale import wrap, Config, OptimizationMode
config = Config(mode=OptimizationMode.AGGRESSIVE)
trainer = wrap(trainer, config)
trainer.train()
Autoregressive models and use_cache
For HF autoregressive models (GPT-2, Llama, Mistral, T5, BART, …), MemScale
automatically sets config.use_cache = False on the model when gradient
checkpointing is enabled. The KV-cache and activation checkpointing are
incompatible — the cache changes activation shapes, which breaks the
backward recompute.
MemScale handles this for you and logs an info message when it does. No action is needed on your part; it only affects training, not a later inference pass.
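For readers who want to see what this adjustment looks like, here is a minimal sketch of the behavior described above. It is not MemScale's actual code: the helper name `disable_kv_cache_if_checkpointing` is hypothetical, and the model is a stand-in object that mimics the `config.use_cache` and `is_gradient_checkpointing` attributes found on Hugging Face models.

```python
from types import SimpleNamespace


def disable_kv_cache_if_checkpointing(model) -> bool:
    """Hypothetical helper: turn off use_cache when gradient checkpointing is on.

    Returns True if the setting was changed, False otherwise.
    """
    config = getattr(model, "config", None)
    checkpointing = getattr(model, "is_gradient_checkpointing", False)
    if config is None or not checkpointing:
        return False
    if getattr(config, "use_cache", False):
        config.use_cache = False
        return True
    return False


# Stand-in for a Hugging Face model with checkpointing already enabled.
fake_model = SimpleNamespace(
    config=SimpleNamespace(use_cache=True),
    is_gradient_checkpointing=True,
)

changed_first = disable_kv_cache_if_checkpointing(fake_model)
changed_again = disable_kv_cache_if_checkpointing(fake_model)  # nothing left to change
```

Without MemScale, the equivalent manual step on a real model would be calling `model.gradient_checkpointing_enable()` and then setting `model.config.use_cache = False`.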
Throughput reporting
Since v1.1.0, the Hugging Face integration reports real samples/sec
throughput (it was previously hard-coded to 0.0). You will see accurate
throughput in the standard Trainer logs.
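As a sanity check on the reported number, samples/sec is simply the number of samples processed divided by wall-clock time. The sketch below shows that arithmetic; how MemScale measures it internally is not documented here, and the function name `samples_per_second` is a hypothetical one chosen for illustration.

```python
def samples_per_second(num_samples: int, elapsed_seconds: float) -> float:
    """Hypothetical helper: throughput as samples processed per wall-clock second."""
    if elapsed_seconds <= 0:
        return 0.0  # avoid division by zero when no time has been measured
    return num_samples / elapsed_seconds


# Example: 250 steps at a per-device batch size of 4 on one device, in 20 seconds.
steps = 250
batch_size = 4
throughput = samples_per_second(steps * batch_size, 20.0)  # 50.0
```

If the value in your logs differs greatly from an estimate like this, check whether gradient accumulation or multiple devices change the effective number of samples per step.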
Tips
- Pass a representative sample_input to wrap() for more accurate profiling; see wrap().
- If you write your own training loop instead of using Trainer, see the PyTorch Training guide.