optimize()
A context manager that applies MemScale optimization for the duration of a
`with` block. Use it in custom training loops where you do not have a
`Trainer` object.
Signature
```python
memscale.optimize(
    model,
    optimizer=None,
    config=None,
    sample_input=None,
)
```

`optimize()` is a context manager (`@contextmanager`).
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | The model to optimize. |
| `optimizer` | `torch.optim.Optimizer` | `None` | The optimizer. When provided, it sharpens the memory-pressure estimate used by the decision engine. |
| `config` | `Config` | `None` | Optimization settings. `None` uses `Config()` (balanced mode). |
| `sample_input` | `torch.Tensor` | `None` | A representative input tensor, used to improve profiling accuracy. |
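The table says that `config=None` falls back to `Config()` in balanced mode. The sketch below shows one common way to express that kind of default in Python. `Mode`, `ConfigSketch`, and `resolve_args()` are hypothetical stand-ins, not MemScale classes; the sketch only illustrates resolving `None` to a fresh default object on each call instead of sharing one default instance.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Mode(Enum):
    # Hypothetical stand-in for the modes mentioned on this page.
    BALANCED = "balanced"
    AGGRESSIVE = "aggressive"


@dataclass
class ConfigSketch:
    # Hypothetical stand-in for memscale.Config; defaults to balanced mode.
    mode: Mode = Mode.BALANCED


def resolve_args(model, optimizer=None,
                 config: Optional[ConfigSketch] = None,
                 sample_input=None):
    # None means "use the defaults": build a new balanced config per call.
    if config is None:
        config = ConfigSketch()
    # The optimizer and sample input remain optional hints.
    has_optimizer = optimizer is not None
    has_sample = sample_input is not None
    return config, has_optimizer, has_sample
```

Creating the default inside the function means two calls never share (and accidentally mutate) the same config object.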
Yields
The MemScale executor for the wrapped run. Most training loops do not
need to touch it, but it is available as the `as` target
(`with memscale.optimize(...) as ms:`).
When the `with` block exits, MemScale detaches automatically; you do not
need to call `detach()` yourself.
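The attach-on-enter, detach-on-exit behavior described above follows the standard `contextlib.contextmanager` pattern. The sketch below is not MemScale's implementation: `_FakeExecutor`, `attach()`, and `optimize_sketch()` are hypothetical stand-ins that only show how a generator-based context manager yields an object to the `as` target and runs cleanup when the block exits. Placing the cleanup in a `finally` clause is a choice made in this sketch so that it also runs when the block raises; this page does not say how `optimize()` itself handles that case.

```python
from contextlib import contextmanager


class _FakeExecutor:
    """Hypothetical stand-in for the object optimize() yields."""

    def __init__(self, model):
        self.model = model
        self.attached = False

    def attach(self):
        # A real optimizer would install its hooks on the model here.
        self.attached = True

    def detach(self):
        # A real optimizer would remove those hooks here.
        self.attached = False


@contextmanager
def optimize_sketch(model):
    executor = _FakeExecutor(model)
    executor.attach()
    try:
        yield executor  # bound to the `as` target
    finally:
        executor.detach()  # runs on normal exit and when the block raises


with optimize_sketch(model="toy") as ex:
    print(ex.attached)  # True inside the block
print(ex.attached)  # False after the block exits
```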
Example
```python
import memscale
import torch

# `model` and `dataloader` are assumed to be defined elsewhere;
# the model returns an output object with a `.loss` attribute.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# optimization is automatically removed here
```

With a custom config
```python
from memscale import optimize, Config, OptimizationMode

config = Config(mode=OptimizationMode.AGGRESSIVE)

with optimize(model, optimizer, config=config):
    train()  # your own training loop
```

optimize() vs. wrap()
| | `wrap()` | `optimize()` |
|---|---|---|
| Form | function | context manager |
| Scope | until you call `detach()` | the `with` block |
| Optimizer-aware | no | yes (optional) |
| Best for | `Trainer` or persistent models | hand-written training loops |
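The Scope row is the practical difference: a `wrap()`-style handle stays active until something calls `detach()`, while `optimize()` ties that call to the `with` block. The sketch below uses hypothetical stand-ins (`Handle`, `wrap_sketch()`), not MemScale's `wrap()`, whose return type this page does not document. It shows that a persistent handle can be given block scope with `contextlib.ExitStack.callback`, which runs the registered `detach()` when the stack closes.

```python
from contextlib import ExitStack


class Handle:
    """Hypothetical stand-in for whatever a wrap()-style call returns."""

    def __init__(self):
        self.active = True

    def detach(self):
        self.active = False


def wrap_sketch(model):
    # Persistent: stays active until detach() is called explicitly.
    return Handle()


# Persistent scope: the caller owns the detach() call.
persistent = wrap_sketch(model="toy")
# ... any amount of work later ...
persistent.detach()

# Block scope: ExitStack calls detach() when the with block exits.
with ExitStack() as stack:
    scoped = wrap_sketch(model="toy")
    stack.callback(scoped.detach)
    print(scoped.active)  # True inside the block
print(scoped.active)  # False after the block exits
```

The persistent form suits long-lived objects such as a `Trainer`; the block form keeps cleanup next to the loop it belongs to.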