
PyTorch Training Loop

When you write your own training loop instead of using the Hugging Face Trainer, you can apply MemScale in two ways. Pick whichever fits your code.

Option A — optimize() context manager

optimize() scopes the optimization to a with block and detaches automatically when the block exits. It is the most natural fit for a hand-written loop.
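Scoping like this is the standard Python context-manager pattern: setup runs on entry, and teardown runs in a finally clause on exit. A minimal sketch of that pattern follows. attach, detach, and scoped_optimization are hypothetical names, a dict stands in for the model, and nothing here reflects MemScale's actual internals. The page does not say whether optimize() also detaches when the block raises, although a context manager built this way would.

```python
from contextlib import contextmanager

# Hypothetical stand-ins: they only flip a flag so the lifecycle is visible.
# They are NOT MemScale functions.
def attach(model):
    model["optimized"] = True

def detach(model):
    model["optimized"] = False

@contextmanager
def scoped_optimization(model):
    """Attach on entry; always detach on exit, even if the body raises."""
    attach(model)
    try:
        yield model
    finally:
        detach(model)

model = {"optimized": False}  # a dict standing in for an nn.Module

with scoped_optimization(model):
    assert model["optimized"]      # active inside the block

assert not model["optimized"]      # removed on normal exit

try:
    with scoped_optimization(model):
        raise RuntimeError("simulated crash mid-epoch")
except RuntimeError:
    pass

assert not model["optimized"]      # removed even after an exception
```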

import memscale
import torch
import torch.nn as nn
 
# MyModel, num_epochs, and dataloader are your own definitions.
model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()
 
with memscale.optimize(model, optimizer) as ms:
    for epoch in range(num_epochs):
        for inputs, targets in dataloader:
            inputs, targets = inputs.cuda(), targets.cuda()
 
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
 
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
# optimization automatically removed here

Passing the optimizer lets the decision engine include it in its memory-pressure estimate, which makes that estimate more accurate.
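How MemScale's decision engine uses the optimizer is not documented here, but a rough calculation shows why the optimizer matters for memory. AdamW keeps two extra fp32 tensors per parameter (exp_avg and exp_avg_sq), allocated on the first optimizer.step(), so in plain fp32 training it doubles the memory taken by weights and gradients. The sketch assumes every tensor is fp32 and ignores activations; the 125M parameter count is an arbitrary example.

```python
def training_state_bytes(num_params, bytes_per_value=4, optimizer_states=2):
    """Rough bytes for weights + gradients + per-parameter optimizer state.

    Assumes every tensor is fp32 (4 bytes) and ignores activations,
    buffers, and allocator overhead. optimizer_states=2 matches AdamW,
    which keeps exp_avg and exp_avg_sq for each parameter.
    """
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    optim = num_params * bytes_per_value * optimizer_states
    return weights + grads + optim

n = 125_000_000  # example model size, chosen for illustration
without_optimizer = training_state_bytes(n, optimizer_states=0)
with_adamw = training_state_bytes(n, optimizer_states=2)

print(without_optimizer / 1e9)  # 1.0 (GB) for weights + grads
print(with_adamw / 1e9)         # 2.0 (GB) once AdamW state is counted
```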

Option B — wrap()

If you prefer the optimization to persist beyond a single block, use wrap() and call memscale.detach(model) when you are done.

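import memscale

# model, optimizer, loss_fn, and dataloader are set up as in Option A.
model = memscale.wrap(model)

for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()
        outputs = model(inputs)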
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
 
memscale.detach(model)   # restore the model for inference / eval
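With wrap(), cleanup is manual, so an exception in the middle of training would skip the memscale.detach(model) call. Putting that call in a finally clause guarantees it runs. In the sketch below, wrap, detach, and train_one_epoch are hypothetical stand-ins that only record events, so the control flow runs without MemScale or a GPU; in real code, memscale.wrap and memscale.detach go in the same positions.

```python
# Hypothetical stand-ins for memscale.wrap / memscale.detach so the
# control flow can run without MemScale installed.
events = []

def wrap(model):
    events.append("wrap")
    return model

def detach(model):
    events.append("detach")

def train_one_epoch(model, fail=False):
    # Placeholder for the real loop body; optionally simulates a crash.
    events.append("train")
    if fail:
        raise RuntimeError("simulated out-of-memory")

model = object()
model = wrap(model)
try:
    for epoch in range(2):
        train_one_epoch(model, fail=(epoch == 1))
except RuntimeError:
    pass  # a real script would log or re-raise here
finally:
    detach(model)  # runs whether training finished or raised

print(events)  # ['wrap', 'train', 'train', 'detach']
```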

Which to use

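|                 | optimize()               | wrap()                     |
|-----------------|--------------------------|----------------------------|
| Lifetime        | the with block           | until detach()             |
| Cleanup         | automatic                | manual (detach())          |
| Optimizer-aware | yes                      | no                         |
| Best for        | a contained training run | a model used across phases |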

With a custom config

Both optimize() and wrap() accept a Config. For example, with optimize():

import memscale
from memscale import Config, OptimizationMode

config = Config(mode=OptimizationMode.AGGRESSIVE, use_mixed_precision=True)

with memscale.optimize(model, optimizer, config=config):
    train()  # your training loop, e.g. the one from Option A
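This page does not say what use_mixed_precision changes internally. In PyTorch, mixed precision usually means running much of the forward pass in float16 or bfloat16 (for example with torch.autocast), and any activation saved in a 16-bit dtype takes half the memory of its float32 version. A back-of-envelope illustration, assuming a hypothetical model shape and that all saved activations end up in bfloat16:

```python
# Bytes per element for common floating-point dtypes.
BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}

def activation_bytes(num_elements, dtype):
    """Memory for saved activations holding num_elements values of dtype."""
    return num_elements * BYTES[dtype]

# Illustrative shape: batch 32, sequence 512, hidden 1024, 24 layers,
# one saved tensor of that size per layer (real models save several).
elements = 32 * 512 * 1024 * 24

fp32 = activation_bytes(elements, "float32")
bf16 = activation_bytes(elements, "bfloat16")
print(fp32 / 2**30, bf16 / 2**30)  # 1.5 0.75 (GiB)
```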

Notes

  • MemScale does not change your loss, backward, or optimizer-step calls — write the loop exactly as you normally would.
  • For accurate profiling, pass a representative sample_input: a batch with the same shape and dtype as the batches your loop actually trains on.
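The page does not define "representative". Because activation memory grows with batch size and sequence length, a sample much smaller than your real batches would likely make the profile underestimate peak memory. One conservative approach, offered as an assumption rather than MemScale guidance, is to profile on the largest input shape your loader produces. pick_profiling_shape and the shapes below are hypothetical:

```python
import math

def numel(shape):
    """Number of elements in a tensor of the given shape."""
    return math.prod(shape)

def pick_profiling_shape(batch_shapes):
    """Return the batch shape with the most elements.

    Activation memory grows with batch size and sequence length, so
    profiling on the largest batch avoids underestimating peak usage.
    """
    return max(batch_shapes, key=numel)

# (batch, seq_len) shapes from a hypothetical variable-length dataloader.
shapes = [(32, 128), (32, 512), (8, 1024), (1, 16)]

chosen = pick_profiling_shape(shapes)
print(chosen)                           # (32, 512)
print(numel(chosen) // numel((1, 16)))  # 1024: elements vs. a toy input
```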