Overview
MemScale turns a plain PyTorch model into a memory-optimized one through a short pipeline. This page describes that pipeline end to end.
The big picture
User Model
  │
  ▼
wrap() ──► Profiler ──► Decision Engine ──► Execution Plan ──► Executor
              │                │                                  │
       detect hardware  per-layer plan                      attach hooks
       profile graph     (rule-based)                             │
                                                                  ▼
                                                         Optimized Training

When you call wrap(), MemScale runs these stages in order.
The three layers
1. User API
The surface you interact with: wrap(), optimize(), and detach(), configured by a Config object. Most users never go below this layer.
2. Decision Engine
Given a profiled model and the detected hardware, the decision engine produces an execution plan — a mapping of each layer to the techniques it should receive. The engine is rule-based and deterministic: identical inputs always yield the same plan. See Decision Engine.
In v1.2, an optional ML policy runs before this engine to pick the high-level strategy, but the per-layer expansion is still performed by the same deterministic rule engine.
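To make "rule-based and deterministic" concrete, here is a minimal sketch of a planner that maps profiled layers and hardware to a per-layer plan. It is not MemScale's engine: the names LayerProfile, Hardware, and build_plan, the thresholds, and the two rules are all assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerProfile:
    name: str
    activation_mb: float  # estimated activation memory for this layer
    param_mb: float       # estimated parameter memory for this layer


@dataclass(frozen=True)
class Hardware:
    gpu_count: int
    vram_mb: float        # VRAM per GPU


def build_plan(layers, hw):
    """Map each layer name to a tuple of technique names using fixed rules."""
    total_mb = sum(lp.activation_mb + lp.param_mb for lp in layers)
    pressure = total_mb / (hw.vram_mb * hw.gpu_count)
    plan = {}
    for layer in layers:
        techniques = []
        # Rule 1 (hypothetical): under pressure, checkpoint activation-heavy layers.
        if pressure > 0.8 and layer.activation_mb > layer.param_mb:
            techniques.append("gradient_checkpointing")
        # Rule 2 (hypothetical): if the model does not fit, offload parameter-heavy layers.
        if pressure > 1.0 and layer.param_mb >= layer.activation_mb:
            techniques.append("cpu_offload")
        plan[layer.name] = tuple(techniques)
    return plan


layers = [
    LayerProfile("embed", activation_mb=200.0, param_mb=1500.0),
    LayerProfile("block0", activation_mb=3000.0, param_mb=400.0),
    LayerProfile("head", activation_mb=100.0, param_mb=100.0),
]
hw = Hardware(gpu_count=1, vram_mb=4096.0)

plan = build_plan(layers, hw)
```

Because build_plan uses no randomness and no hidden state, calling it twice with the same layers and hardware returns an identical plan, which is the property the decision engine guarantees.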
3. Executor
The executor takes the execution plan and attaches the actual hooks to the model: checkpoint wrappers, offload hooks, tiling, and precision casts. It also stores a handle on the model (model._memscale_executor) so that detach() can cleanly reverse everything.
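The reversible-attach pattern can be sketched without any framework. The example below is an assumption-laden stand-in, not MemScale's executor: SketchExecutor, the wrappers argument, and the toy Layer and Model classes are hypothetical, and only the attribute name _memscale_executor comes from this page.

```python
class SketchExecutor:
    """Records how to undo every change it makes (hypothetical stand-in)."""

    def __init__(self):
        self._undo = []  # callables that each restore one change

    def attach(self, model, plan, wrappers):
        for layer_name, techniques in plan.items():
            layer = getattr(model, layer_name)
            for technique in techniques:
                original = layer.forward
                layer.forward = wrappers[technique](original)
                # Remember how to restore this exact layer's forward.
                self._undo.append(
                    lambda l=layer, f=original: setattr(l, "forward", f)
                )
        # Store the handle on the model so a later detach can find it.
        model._memscale_executor = self

    def detach(self, model):
        # Undo in reverse order so stacked wrappers unwind correctly.
        while self._undo:
            self._undo.pop()()
        del model._memscale_executor


class Layer:
    def forward(self, x):
        return x * 2


class Model:
    def __init__(self):
        self.block0 = Layer()


calls = []


def logging_wrapper(forward):
    # Stand-in for a checkpoint or offload wrapper: logs, then delegates.
    def wrapped(x):
        calls.append("hooked")
        return forward(x)
    return wrapped


model = Model()
executor = SketchExecutor()
executor.attach(model, {"block0": ("log",)}, {"log": logging_wrapper})
hooked_result = model.block0.forward(3)   # goes through the wrapper
model._memscale_executor.detach(model)    # look up the stored handle, then undo
plain_result = model.block0.forward(3)    # wrapper is gone
```

Keeping the undo list on an object reachable from the model means detaching needs no global registry: whoever holds the model can find the executor and reverse its changes.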
Profiling
Before the decision engine runs, a profiler:
- detects hardware: GPU count, total VRAM, compute capability;
- profiles the model graph: layer structure and per-layer memory cost, using torch.fx static analysis where possible and falling back to empirical runtime profiling otherwise.
Passing a sample_input to wrap() improves profiling accuracy.
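As a rough illustration of what a per-layer memory cost can look like, the sketch below estimates parameter and output-activation memory for a dense layer from its shape. The function names and the cost model are assumptions for illustration; this page does not describe the formula MemScale's profiler actually uses.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def linear_param_bytes(in_features, out_features, dtype="float32", bias=True):
    """Parameter memory of a dense layer: weight matrix plus optional bias."""
    n_params = in_features * out_features + (out_features if bias else 0)
    return n_params * BYTES_PER_ELEMENT[dtype]


def activation_bytes(batch_size, out_features, dtype="float32"):
    """Memory of the layer's output tensor for one batch."""
    return batch_size * out_features * BYTES_PER_ELEMENT[dtype]


# A 4096 x 4096 dense layer with bias, stored in float32.
params = linear_param_bytes(4096, 4096)
# Output activations for a batch of 32 rows.
acts = activation_bytes(32, 4096)
```

Parameter memory follows from the layer definition alone, while activation memory depends on the input shape, which a static view of the parameters cannot supply.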
What the executor applies
The techniques the plan can assign are covered in detail under Techniques: gradient checkpointing, mixed precision, CPU offloading, activation tiling, and 8-bit optimizers.