
Hugging Face Trainer

MemScale integrates with the Hugging Face Trainer directly: pass the Trainer to wrap() and keep calling .train() as usual.

⚠️ There is no apply_all_optimizations() function. Earlier drafts mentioned one; it does not exist in MemScale. The integration is just wrap().

Wrapping a Trainer

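import memscale
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
 
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a padding token
 
# any tokenized causal-LM dataset works; a small wikitext slice is used for illustration
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
raw = raw.filter(lambda ex: ex["text"].strip() != "")  # drop empty lines
train_dataset = raw.map(lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
                        batched=True, remove_columns=["text"])
 
args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=4,
    num_train_epochs=1,
)
 
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # builds labels from input_ids
)
 
# one line: wrap the Trainer before training
trainer = memscale.wrap(trainer)
 
trainer.train()   # VRAM optimized automatically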

wrap() detects that it received a Trainer (it has both a .model and a .train method), prepares the inner model, and attaches the optimization hooks. The Trainer API is otherwise untouched.
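The detection rule can be pictured as a simple duck-typing check. The sketch below is illustrative only: looks_like_trainer, TinyModel and TinyTrainer are hypothetical names, not MemScale internals, and the check just mirrors the rule stated above. It also shows why both attributes matter: PyTorch modules have a .train() method too, so .train alone would not tell a model from a Trainer.

```python
def looks_like_trainer(obj) -> bool:
    """Hypothetical helper mirroring the documented rule: a Trainer-like
    object exposes a .model attribute and a callable .train method."""
    return hasattr(obj, "model") and callable(getattr(obj, "train", None))


class TinyModel:
    """Stand-in for a model: it has .train() (like torch modules) but no .model."""

    def train(self, mode=True):
        return self


class TinyTrainer:
    """Stand-in for a Trainer: it holds a model and can be trained."""

    def __init__(self, model):
        self.model = model

    def train(self):
        return "training started"


for target in (TinyTrainer(TinyModel()), TinyModel()):
    kind = "Trainer" if looks_like_trainer(target) else "model"
    print(type(target).__name__, "->", kind)
# prints: TinyTrainer -> Trainer
# prints: TinyModel -> model
```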

With a custom config

from memscale import wrap, Config, OptimizationMode
 
config = Config(mode=OptimizationMode.AGGRESSIVE)
trainer = wrap(trainer, config)
trainer.train()

Autoregressive models and use_cache

For Hugging Face models that decode autoregressively (GPT-2, Llama, Mistral, T5, BART, …), MemScale automatically sets config.use_cache = False on the model when gradient checkpointing is enabled. The KV cache only speeds up token-by-token generation, so training does not need it, and it is incompatible with activation checkpointing: the cache changes activation shapes, which breaks the backward recompute.

MemScale handles this for you and logs an info message when it does. No action is needed on your part; it only affects training, not a later inference pass.
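To see what that adjustment amounts to, here is a sketch of the same two settings applied by hand in plain Transformers. It assumes checkpointing is turned on through Transformers' own gradient_checkpointing_enable(); MemScale may enable checkpointing differently internally. The tiny randomly initialised GPT-2 config exists only so the example runs without downloading weights.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny, randomly initialised GPT-2 so the sketch runs without downloading weights
config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=128, n_positions=64)
model = GPT2LMHeadModel(config)

# The adjustment described above, done by hand:
# enable activation checkpointing and switch off the KV cache for training.
model.gradient_checkpointing_enable()
model.config.use_cache = False

model.train()
input_ids = torch.randint(0, config.vocab_size, (2, 16))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()  # checkpointed layers are recomputed here; no KV cache was built

# For a later generation pass, request the cache explicitly for that call.
model.eval()
generated = model.generate(
    input_ids[:, :4],
    max_new_tokens=3,
    do_sample=False,
    use_cache=True,
    pad_token_id=0,
)
print(generated.shape)
```

Passing use_cache=True to generate() enables the cache for that call regardless of the value left in model.config, which is one way to keep inference fast on a model object that was previously set up for training.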

Throughput reporting

Since v1.1.0, the Hugging Face integration reports measured samples/sec throughput; earlier versions reported a hard-coded 0.0. The measured figure appears in the standard Trainer logs.
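If you want an independent sanity check on the reported figure, a small Trainer callback can estimate samples/sec from the optimizer step count. This is a hypothetical helper, not part of MemScale, and it does not reproduce how MemScale computes its number. It assumes every optimizer step consumes a full effective batch (per-step batch × gradient accumulation × processes) and measures wall time from the start of training, so early readings include setup time and run low.

```python
import time
from types import SimpleNamespace

from transformers import TrainerCallback


class SamplesPerSecondCheck(TrainerCallback):
    """Hypothetical cross-check (not part of MemScale): estimates samples/sec
    from optimizer steps, assuming each step sees a full effective batch."""

    def __init__(self):
        self.start = None
        self.last_samples_per_sec = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.start = time.perf_counter()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.start is None:
            return
        elapsed = time.perf_counter() - self.start
        # effective batch = per-step batch x gradient accumulation x processes
        samples = (state.global_step * args.train_batch_size
                   * args.gradient_accumulation_steps * args.world_size)
        if elapsed > 0:
            self.last_samples_per_sec = samples / elapsed


# Exercise the hooks with stand-in objects instead of a real training run.
check = SamplesPerSecondCheck()
fake_args = SimpleNamespace(train_batch_size=4, gradient_accumulation_steps=2, world_size=1)
check.on_train_begin(fake_args, SimpleNamespace(global_step=0), None)
time.sleep(0.05)
check.on_log(fake_args, SimpleNamespace(global_step=10), None, logs={})
print(f"{check.last_samples_per_sec:.1f} samples/sec")
```

With a real Trainer you would register it with trainer.add_callback(SamplesPerSecondCheck()) and compare check.last_samples_per_sec against the logged throughput; a short final batch makes this estimate slightly high.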

Tips

  • Pass a representative sample_input to wrap() for more accurate profiling (see the wrap() reference).
  • If you write your own training loop instead of using Trainer, see the PyTorch Training guide.