
optimize()

A context manager that applies MemScale optimization for the duration of a with block. Use it for custom training loops where you do not have a Trainer object.

Signature

memscale.optimize(
    model,
    optimizer=None,
    config=None,
    sample_input=None,
)

optimize() is a context manager (@contextmanager).

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model | nn.Module | required | The model to optimize. |
| optimizer | torch.optim.Optimizer | None | The optimizer. When provided, it sharpens the memory-pressure estimate used by the decision engine. |
| config | Config | None | Optimization settings. None uses Config() (balanced mode). |
| sample_input | torch.Tensor | None | A representative input tensor. Improves profiling accuracy. |
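The default rule for config ("None uses Config()") can be sketched in a few lines. This is not MemScale's code. Config and OptimizationMode below are local stand-ins (only AGGRESSIVE is named in these docs; a BALANCED member is assumed from the "balanced mode" default), and resolve_config is a hypothetical helper.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class OptimizationMode(Enum):
    # Hypothetical stand-in: AGGRESSIVE appears in these docs,
    # BALANCED is assumed from the "balanced mode" default.
    BALANCED = "balanced"
    AGGRESSIVE = "aggressive"


@dataclass
class Config:
    # Hypothetical stand-in for memscale.Config; only `mode` is modeled.
    mode: OptimizationMode = OptimizationMode.BALANCED


def resolve_config(config: Optional[Config] = None) -> Config:
    # Documented rule: passing None behaves like passing Config().
    return config if config is not None else Config()


print(resolve_config().mode)  # OptimizationMode.BALANCED
print(resolve_config(Config(mode=OptimizationMode.AGGRESSIVE)).mode)  # OptimizationMode.AGGRESSIVE
```

The sketch tests "is not None" rather than using "config or Config()", so a caller-supplied object is never replaced just because it happens to evaluate as falsy.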

Yields

The MemScale executor for the wrapped run. Most training loops never need it, but you can capture it with the with statement's as clause (for example, as ms).

When the with block exits, MemScale detaches automatically, so you do not call detach() yourself.
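The yield-then-detach lifecycle can be illustrated with a toy generator-based context manager built on contextlib.contextmanager. ToyExecutor and toy_optimize are hypothetical names, not MemScale internals; the sketch only mirrors the documented behavior (an executor is yielded as the as target and detached when the block exits). The try/finally means this sketch also detaches when the block raises; these docs do not say whether MemScale does the same.

```python
from contextlib import contextmanager


class ToyExecutor:
    """Hypothetical stand-in for the object optimize() yields."""

    def __init__(self, model):
        self.model = model
        self.attached = False

    def attach(self):
        self.attached = True

    def detach(self):
        self.attached = False


@contextmanager
def toy_optimize(model, optimizer=None, config=None, sample_input=None):
    # Same parameter list as memscale.optimize(); the body is illustrative only.
    executor = ToyExecutor(model)
    executor.attach()
    try:
        yield executor  # becomes the `as` target of the with statement
    finally:
        executor.detach()  # runs on normal exit and, in this sketch, on exceptions


with toy_optimize(object()) as ms:
    print("inside:", ms.attached)  # inside: True
print("after:", ms.attached)  # after: False
```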

Example

import memscale
import torch

# Assumes model and dataloader are already defined.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# ms is the yielded executor; this loop does not need it.
with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
# optimization is automatically removed here

With a custom config

from memscale import optimize, Config, OptimizationMode

config = Config(mode=OptimizationMode.AGGRESSIVE)

# model and optimizer as in the previous example; train() stands for your own loop.
with optimize(model, optimizer, config=config):
    train()

optimize() vs. wrap()

| | wrap() | optimize() |
|---|---|---|
| Form | function | context manager |
| Scope | until you call detach() | the with block |
| Optimizer-aware | no | yes (optional) |
| Best for | Trainer or persistent models | hand-written training loops |
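The scope difference in the table comes down to who calls detach. A minimal sketch, assuming a function-style call that returns a handle with a detach() method (toy_wrap, ToyHandle, and toy_optimize are hypothetical; these docs do not say whether MemScale's detach() is a method on a returned object or a module-level function): the function form stays active until you detach it, while the context-manager form is the same attach/detach pair scoped to one with block.

```python
from contextlib import contextmanager


class ToyHandle:
    """Hypothetical handle returned by a wrap()-style call."""

    def __init__(self, model):
        self.model = model
        self.active = True

    def detach(self):
        self.active = False


def toy_wrap(model):
    # Function form: optimization stays active until detach() is called.
    return ToyHandle(model)


@contextmanager
def toy_optimize(model):
    # Context-manager form: the same attach/detach pair, scoped to a with block.
    handle = toy_wrap(model)
    try:
        yield handle
    finally:
        handle.detach()


# Persistent use: one handle spans several phases until you detach it.
h = toy_wrap(object())
try:
    for phase in ("train", "eval", "train"):
        print(phase, "active:", h.active)  # active: True for every phase
finally:
    h.detach()

# Scoped use: detach happens automatically at the end of the block.
with toy_optimize(object()) as h2:
    print("inside:", h2.active)  # inside: True
print("after:", h2.active)  # after: False
```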