wrap()
Wrap a model or trainer with MemScale optimizations. This is the primary entry point.
Signature
memscale.wrap(
    target,
    config=None,
    sample_input=None,
    async_offload=None,
    async_config=None,
)
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| target | nn.Module or HF Trainer | (required) | The model or Hugging Face Trainer to optimize. |
| config | Config | None | Optimization settings. None uses Config() (balanced mode). |
| sample_input | torch.Tensor | None | A representative input tensor. Improves profiling accuracy. |
| async_offload | bool | None | Enable the experimental tier-aware async CPU offload engine. None falls back to the MEMSCALE_ASYNC_OFFLOAD env var, then to the config value. |
| async_config | AsyncOffloadConfig | None | Power-user tuning for the async offload path (pinned-pool size, prefetch depth, debug). |
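The fallback order described for async_offload (explicit argument, then the MEMSCALE_ASYNC_OFFLOAD env var, then the config value) can be sketched as below. This is an illustration of the documented precedence only, not MemScale's implementation: the helper name resolve_async_offload and the set of env-var spellings treated as "on" are assumptions.

```python
import os
from typing import Mapping, Optional

# Assumed spellings that count as "enabled"; the real parsing rules
# for MEMSCALE_ASYNC_OFFLOAD are not stated in this reference.
_TRUTHY = {"1", "true", "yes", "on"}


def resolve_async_offload(
    explicit: Optional[bool],
    config_value: bool,
    env: Optional[Mapping[str, str]] = None,
) -> bool:
    """Hypothetical helper: return the effective async-offload setting.

    Precedence: explicit argument, then env var, then config value.
    """
    env = os.environ if env is None else env
    # 1. An explicit True/False passed by the caller always wins.
    if explicit is not None:
        return explicit
    # 2. Otherwise consult the environment variable, if it is set.
    raw = env.get("MEMSCALE_ASYNC_OFFLOAD")
    if raw is not None:
        return raw.strip().lower() in _TRUTHY
    # 3. Finally fall back to whatever the config specifies.
    return config_value
```

Under these assumptions, passing async_offload=False disables the engine even when the env var is set, while leaving it as None lets deployment settings decide.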
⚠️
wrap() does not take an optimizer argument. To optimize a training
loop together with its optimizer, use optimize() instead.
Returns
The same target object, with optimization hooks attached. A wrapped
nn.Module is returned as an nn.Module; a wrapped Trainer as a
Trainer. You continue to use the object exactly as before.
What it accepts
wrap() detects what you passed:
- a Hugging Face Trainer: the inner model is prepared and wrapped;
- a raw nn.Module (including an HF model such as GPT2LMHeadModel).
Anything else raises TypeError.
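The detection rule above can be pictured as a simple isinstance dispatch. The sketch below is an assumption about the shape of that check, not the library's code: Module and Trainer are dependency-free stand-ins for torch.nn.Module and transformers.Trainer, and classify_target is a hypothetical name.

```python
class Module:
    """Stand-in for torch.nn.Module (keeps the sketch dependency-free)."""


class Trainer:
    """Stand-in for transformers.Trainer; holds the inner model."""

    def __init__(self, model: Module) -> None:
        self.model = model


def classify_target(target: object) -> str:
    """Hypothetical dispatch: report which path wrap() would take."""
    # A Trainer carries an inner model, so it is checked first.
    if isinstance(target, Trainer):
        return "trainer"
    if isinstance(target, Module):
        return "module"
    # Anything else is rejected, matching the documented TypeError.
    raise TypeError(
        f"expected an nn.Module or a Trainer, got {type(target).__name__}"
    )
```

In practice this means a model name string or a plain dict fails fast with TypeError rather than being silently ignored.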
Examples
Raw module
import memscale
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
model = memscale.wrap(model)
With a custom config
from memscale import wrap, Config, OptimizationMode
config = Config(mode=OptimizationMode.AGGRESSIVE)
model = wrap(model, config)
With a sample input for accurate profiling
import torch
sample = torch.randint(0, 50257, (4, 512)) # batch 4, seq 512
model = memscale.wrap(model, sample_input=sample)
Hugging Face Trainer
from transformers import Trainer
trainer = Trainer(model=model, args=args, ...)
trainer = memscale.wrap(trainer)
trainer.train()
Notes
- For HF autoregressive models, wrap() automatically sets config.use_cache = False when checkpointing is enabled, to avoid a recompute shape mismatch. It logs an info message when it does this.
- To reverse wrap(), call detach().
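The use_cache note can be illustrated with a minimal sketch. It assumes the flag lives on model.config (as it does for Hugging Face models) and uses a SimpleNamespace stand-in instead of a real model; the function name, logger name, and log text are hypothetical, not MemScale's.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("memscale_sketch")  # hypothetical logger name


def disable_cache_if_checkpointing(model, checkpointing_enabled: bool) -> bool:
    """Set model.config.use_cache = False when checkpointing is on.

    Returns True if the flag was changed. Illustrative only.
    """
    config = getattr(model, "config", None)
    # Nothing to do without checkpointing or without a config object.
    if not checkpointing_enabled or config is None:
        return False
    if getattr(config, "use_cache", False):
        config.use_cache = False
        logger.info("use_cache disabled because checkpointing is enabled")
        return True
    return False


# Stand-in for an HF autoregressive model whose config has use_cache=True.
model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
changed = disable_cache_if_checkpointing(model, checkpointing_enabled=True)
```

If you rely on the KV cache for generation after training, check the flag on your model's config and re-enable it yourself as needed.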