wrap()

Wrap a model or trainer with MemScale optimizations. This is the primary entry point.

Signature

memscale.wrap(
    target,
    config=None,
    sample_input=None,
    async_offload=None,
    async_config=None,
)

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| target | nn.Module or HF Trainer | (required) | The model or Hugging Face Trainer to optimize. |
| config | Config | None | Optimization settings. None uses Config() (balanced mode). |
| sample_input | torch.Tensor | None | A representative input tensor. Improves profiling accuracy. |
| async_offload | bool | None | Enable the experimental tier-aware async CPU offload engine. None falls back to the MEMSCALE_ASYNC_OFFLOAD env var, then to the config value. |
| async_config | AsyncOffloadConfig | None | Power-user tuning for the async offload path (pinned-pool size, prefetch depth, debug). |
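
The fallback order described for async_offload (explicit argument, then the MEMSCALE_ASYNC_OFFLOAD environment variable, then the config value) can be pictured with the sketch below. It is an illustration only, not MemScale's implementation: resolve_async_offload is a hypothetical helper, and the set of strings treated as "enabled" for the environment variable is an assumption.

```python
import os

# Assumption: the accepted spellings of the env var are not documented here.
_TRUTHY = {"1", "true", "yes", "on"}


def resolve_async_offload(arg, config_value, environ=None):
    """Hypothetical helper: explicit argument, then env var, then config value."""
    environ = os.environ if environ is None else environ

    # 1. An explicit True/False passed to wrap() wins.
    if arg is not None:
        return bool(arg)

    # 2. Otherwise consult MEMSCALE_ASYNC_OFFLOAD if it is set.
    raw = environ.get("MEMSCALE_ASYNC_OFFLOAD")
    if raw is not None:
        return raw.strip().lower() in _TRUTHY

    # 3. Finally fall back to whatever the config specifies.
    return bool(config_value)
```

Under these assumptions, passing async_offload=True or async_offload=False always overrides both the environment and the config.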
⚠️ wrap() does not take an optimizer argument. To optimize a training loop together with its optimizer, use optimize() instead.

Returns

The same target object, with optimization hooks attached. A wrapped nn.Module is returned as an nn.Module; a wrapped Trainer as a Trainer. You continue to use the object exactly as before.

What it accepts

wrap() detects what you passed:

  • a Hugging Face Trainer — the inner model is prepared and wrapped;
  • a raw nn.Module (including an HF model such as GPT2LMHeadModel).

Anything else raises TypeError.
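
The detection rule above amounts to a type dispatch. The sketch below shows the idea with stand-in classes so it runs without PyTorch or Transformers installed; Module, Trainer, and classify_target are hypothetical names for illustration and say nothing about how MemScale implements the check internally.

```python
class Module:
    """Stand-in for torch.nn.Module (assumption, for illustration only)."""


class Trainer:
    """Stand-in for transformers.Trainer; holds the inner model."""

    def __init__(self, model):
        self.model = model


def classify_target(target):
    """Hypothetical dispatch: report which path wrap() would take."""
    # Check Trainer first: it is the container whose inner model gets wrapped.
    if isinstance(target, Trainer):
        return "trainer"
    if isinstance(target, Module):
        return "module"
    # Anything else is rejected, mirroring the documented TypeError.
    raise TypeError(
        f"expected an nn.Module or a Trainer, got {type(target).__name__}"
    )
```

In this sketch a string such as "gpt2" (a model name rather than a loaded model) falls through to the TypeError branch, which is the behavior the documentation describes for unsupported inputs.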

Examples

Raw module

import memscale
from transformers import AutoModelForCausalLM
 
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
model = memscale.wrap(model)

With a custom config

from memscale import wrap, Config, OptimizationMode
 
config = Config(mode=OptimizationMode.AGGRESSIVE)
model = wrap(model, config)

With a sample input for accurate profiling

import torch
 
sample = torch.randint(0, 50257, (4, 512))   # batch 4, seq 512
model = memscale.wrap(model, sample_input=sample)

Hugging Face Trainer

from transformers import Trainer
 
trainer = Trainer(model=model, args=args, ...)
trainer = memscale.wrap(trainer)
trainer.train()

Notes

  • For HF autoregressive models, wrap() automatically sets the model's config.use_cache to False when checkpointing is enabled, to avoid a shape mismatch during recomputation. It logs an info-level message when it does this.
  • To reverse wrap(), call detach().
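
The use_cache note can be sketched as follows, assuming "config" there refers to the HF model's config attribute. The helper name, the logger name, and the log text are hypothetical, and a SimpleNamespace stands in for a real model so the example runs without Transformers installed.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("memscale_sketch")  # hypothetical logger name


def disable_kv_cache_if_checkpointing(model, checkpointing_enabled):
    """Hypothetical helper; returns True if it changed the model's config."""
    model_config = getattr(model, "config", None)
    # Only act when checkpointing is on and the KV cache is currently enabled.
    if checkpointing_enabled and getattr(model_config, "use_cache", False):
        model_config.use_cache = False
        logger.info("Set use_cache=False because checkpointing is enabled")
        return True
    return False


# Stand-in for an HF autoregressive model with its cache enabled.
model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
changed = disable_kv_cache_if_checkpointing(model, checkpointing_enabled=True)
```

If you rely on the cache for generation after training, a change like this means you may need to set use_cache back to True yourself.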