Overview

MemScale turns a plain PyTorch model into a memory-optimized one through a short pipeline. This page describes that pipeline end to end.

The big picture

User Model
    │
    ▼
  wrap()  ──►  Profiler  ──►  Decision Engine  ──►  Execution Plan  ──►  Executor
                  │                │                                       │
            detect hardware   per-layer plan                        attach hooks
            profile graph     (rule-based)                                 │
                                                                           ▼
                                                                Optimized Training

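When you call wrap(), MemScale runs these stages in order: the profiler inspects the hardware and the model, the decision engine turns that profile into an execution plan, and the executor applies the plan to the model.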
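The ordering can be pictured as three functions chained together. The following is only a toy sketch of that shape: every name in it (Profile, Plan, profile, decide, execute, wrap_sketch) and every number is a hypothetical stand-in, not part of MemScale's API.

```python
from dataclasses import dataclass

# All names and values below are illustrative assumptions, not MemScale code.

@dataclass
class Profile:
    hardware: dict      # e.g. {"vram_bytes": ...}
    layer_costs: dict   # layer name -> estimated bytes

@dataclass
class Plan:
    per_layer: dict     # layer name -> tuple of technique names

def profile(layers: dict) -> Profile:
    # Stage 1: pretend we detected 8 GiB of VRAM and measured each layer.
    return Profile(hardware={"vram_bytes": 8 * 1024**3}, layer_costs=dict(layers))

def decide(p: Profile) -> Plan:
    # Stage 2: one fixed rule, checkpoint any layer above a quarter of VRAM.
    budget = p.hardware["vram_bytes"] // 4
    return Plan({name: ("checkpoint",) if cost > budget else ()
                 for name, cost in p.layer_costs.items()})

def execute(plan: Plan) -> list:
    # Stage 3: stand-in for attaching hooks; just record what would be attached.
    return [(name, tech) for name, techs in plan.per_layer.items() for tech in techs]

def wrap_sketch(layers: dict) -> list:
    # The stages always run in this order: profile -> decide -> execute.
    return execute(decide(profile(layers)))

attached = wrap_sketch({"embed": 1 * 1024**3, "block0": 3 * 1024**3})
print(attached)
```

Each stage consumes only the output of the previous one, which is why the stages can be described and tested separately.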

The three layers

1. User API

The surface you interact with: wrap(), optimize(), and detach(), configured by a Config object. Most users never go below this layer.

2. Decision Engine

Given a profiled model and the detected hardware, the decision engine produces an execution plan — a mapping of each layer to the techniques it should receive. The engine is rule-based and deterministic: identical inputs always yield the same plan. See Decision Engine.

In v1.2 an optional ML policy runs before this engine to pick the high-level strategy, but the per-layer expansion is still performed by the same deterministic rule engine.
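To make "rule-based and deterministic" concrete, here is a minimal sketch of a per-layer rule engine. The rule table, thresholds, technique names, and the build_plan function are assumptions chosen for illustration; MemScale's actual rules are documented under Decision Engine.

```python
from typing import Dict, Tuple

# Hypothetical rule table: (predicate over layer cost and VRAM, technique name).
# Thresholds are made up for this sketch.
RULES = [
    (lambda cost, vram: cost > 0.10 * vram, "gradient_checkpointing"),
    (lambda cost, vram: cost > 0.50 * vram, "cpu_offloading"),
]

def build_plan(layer_costs: Dict[str, int], vram_bytes: int) -> Dict[str, Tuple[str, ...]]:
    """Map each layer to the techniques whose rule fires, in rule order."""
    plan = {}
    for name in sorted(layer_costs):  # fixed iteration order keeps the plan stable
        cost = layer_costs[name]
        plan[name] = tuple(tech for rule, tech in RULES if rule(cost, vram_bytes))
    return plan

GIB = 1024**3
costs = {"attn": 5 * GIB, "mlp": 1 * GIB, "norm": GIB // 100}

plan_a = build_plan(costs, vram_bytes=8 * GIB)
plan_b = build_plan(costs, vram_bytes=8 * GIB)
print(plan_a)
```

Because the rules are pure functions of the layer cost and the hardware, calling build_plan twice with the same inputs gives equal plans, which is the property the engine relies on.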

3. Executor

The executor takes the execution plan and attaches the actual hooks to the model — checkpoint wrappers, offload hooks, tiling, precision casts. It also stores a handle on the model (model._memscale_executor) so that detach() can cleanly reverse everything.
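The attach-then-reverse pattern can be sketched without any framework. In the toy below, ToyLayer, ToyModel, ToyExecutor, the _toy_executor attribute, and detach_sketch are all hypothetical; they only mimic the idea of keeping one undo handle per attached hook (in PyTorch, register_forward_hook similarly returns a handle with a remove() method).

```python
class ToyLayer:
    """Stand-in for a model layer that can hold hooks."""
    def __init__(self):
        self.hooks = []

    def register_hook(self, fn):
        self.hooks.append(fn)
        # Return a callable that undoes exactly this registration.
        return lambda: self.hooks.remove(fn)

class ToyModel:
    def __init__(self, names):
        self.layers = {name: ToyLayer() for name in names}

class ToyExecutor:
    def __init__(self):
        self._undo = []  # one undo callable per attached hook

    def attach(self, model, plan):
        for name, techniques in plan.items():
            for tech in techniques:
                # Placeholder hook: a real one would apply `tech` to the layer.
                hook = lambda x, tech=tech: x
                self._undo.append(model.layers[name].register_hook(hook))
        model._toy_executor = self  # keep a handle so detach can find us

    def detach(self, model):
        while self._undo:
            self._undo.pop()()      # undo in reverse order of attachment
        del model._toy_executor

def detach_sketch(model):
    model._toy_executor.detach(model)

model = ToyModel(["attn", "mlp"])
ToyExecutor().attach(model, {"attn": ("checkpoint", "offload"), "mlp": ("checkpoint",)})
hooks_while_attached = {n: len(l.hooks) for n, l in model.layers.items()}
detach_sketch(model)
hooks_after_detach = {n: len(l.hooks) for n, l in model.layers.items()}
print(hooks_while_attached, hooks_after_detach)
```

Storing the executor on the model means the caller needs nothing but the model itself to undo the changes.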

Profiling

Before the decision engine runs, a profiler:

  • detects hardware — GPU count, total VRAM, compute capability;
  • profiles the model graph — layer structure and per-layer memory cost, using torch.fx static analysis where possible and falling back to empirical runtime profiling otherwise.

Passing a sample_input to wrap() improves profiling accuracy.
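For orientation, the hardware facts listed above can be read with standard PyTorch calls. This is a sketch of one way to collect them, not MemScale's profiler; the detect_hardware function and the dictionary layout are assumptions, and the code reports zero GPUs if PyTorch or CUDA is unavailable.

```python
def detect_hardware():
    """Return GPU count plus per-GPU total VRAM and compute capability."""
    try:
        import torch
    except ImportError:
        return {"gpu_count": 0, "gpus": []}   # PyTorch not installed
    if not torch.cuda.is_available():
        return {"gpu_count": 0, "gpus": []}   # CPU-only environment
    gpus = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        gpus.append({
            "name": props.name,
            "total_vram_bytes": props.total_memory,
            "compute_capability": (props.major, props.minor),
        })
    return {"gpu_count": len(gpus), "gpus": gpus}

hw = detect_hardware()
print(hw)
```

The output depends on the machine it runs on, so no expected values are shown.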

What the executor applies

The techniques the plan can assign are covered in detail under Techniques: gradient checkpointing, mixed precision, CPU offloading, activation tiling, and 8-bit optimizers.