Trainer engine
To separate concerns during model training, we define a trainer engine. For example, it defines an eval_context manager that automatically sets the model to eval mode on entry and restores the previous mode on exit. This is useful for layers such as BN and Dropout, which behave differently at train and test time. LR schedulers and callbacks are also implemented. Currently, these are called at the end of each training step (it is easy to extend this class to implement epoch-end callbacks).
import torch
import numpy as np
from tqdm.notebook import tqdm
from contextlib import contextmanager
from torch.utils.data import DataLoader
@contextmanager
def eval_context(model):
    """Temporarily set to eval mode inside context."""
    is_train = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(is_train)
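As a quick sanity check (a throwaway snippet, not part of the trainer itself), the context manager restores whatever mode the module had before entering:

m = torch.nn.Dropout(0.5)
print(m.training)           # True: modules default to train mode
with eval_context(m):
    print(m.training)       # False: eval mode inside the context
print(m.training)           # True: previous mode restored on exit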
class Trainer:
    def __init__(self,
                 model, optim, loss_fn, scheduler=None, callbacks=[],
                 device=DEVICE, verbose=True
                 ):
        self.model = model.to(device)
        self.optim = optim
        self.device = device
        self.loss_fn = loss_fn
        self.train_log = {"loss": [], "accs": [], "loss_avg": [], "accs_avg": []}
        self.valid_log = {"loss": [], "accs": []}
        self.verbose = verbose
        self.scheduler = scheduler
        self.callbacks = callbacks

    def __call__(self, x):
        return self.model(x.to(self.device))

    def forward(self, batch):
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        return self.model(x), y

    def train_step(self, batch):
        preds, y = self.forward(batch)
        accs = (preds.argmax(dim=1) == y).float().mean()
        loss = self.loss_fn(preds, y)
        loss.backward()
        self.optim.step()
        self.optim.zero_grad()
        return {"loss": loss, "accs": accs}

    @torch.inference_mode()
    def valid_step(self, batch):
        preds, y = self.forward(batch)
        # sum (not mean) so that evaluate() can divide by the dataset size
        accs = (preds.argmax(dim=1) == y).float().sum()
        loss = self.loss_fn(preds, y, reduction="sum")
        return {"loss": loss, "accs": accs}

    def run(self, epochs, train_loader, valid_loader):
        # window for the running averages logged at each train step
        steps_per_epoch = len(train_loader)
        w = max(1, int(0.05 * steps_per_epoch))
        for e in tqdm(range(epochs)):
            for batch in train_loader:
                # optim and lr step
                output = self.train_step(batch)
                if self.scheduler:
                    self.scheduler.step()

                # step callbacks
                for callback in self.callbacks:
                    callback()

                # logs @ train step
                self.train_log["loss"].append(output["loss"].item())
                self.train_log["accs"].append(output["accs"].item())
                self.train_log["loss_avg"].append(np.mean(self.train_log["loss"][-w:]))
                self.train_log["accs_avg"].append(np.mean(self.train_log["accs"][-w:]))

            # logs @ epoch
            output = self.evaluate(valid_loader)
            self.valid_log["loss"].append(output["loss"])
            self.valid_log["accs"].append(output["accs"])
            if self.verbose:
                print(
                    f"[Epoch: {e+1:>0{len(str(epochs))}d}/{epochs}] "
                    f"loss: {self.train_log['loss_avg'][-1]:.4f} "
                    f"acc: {self.train_log['accs_avg'][-1]:.4f} "
                    f"val_loss: {self.valid_log['loss'][-1]:.4f} "
                    f"val_acc: {self.valid_log['accs'][-1]:.4f}"
                )

    def evaluate(self, data_loader):
        with eval_context(self.model):
            valid_loss = 0.0
            valid_accs = 0.0
            for batch in data_loader:
                output = self.valid_step(batch)
                valid_loss += output["loss"].item()
                valid_accs += output["accs"].item()

        return {
            "loss": valid_loss / len(data_loader.dataset),
            "accs": valid_accs / len(data_loader.dataset)
        }

    @torch.inference_mode()
    def predict(self, x: torch.Tensor):
        with eval_context(self.model):
            return self(x)
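To make the scheduler and callback hooks concrete, here is a minimal usage sketch. The synthetic data, network, and hyperparameters below are placeholders for illustration only; the callback simply records the learning rate after every training step.

# usage sketch with a scheduler and a step callback (synthetic data, hypothetical hyperparameters)
import torch.nn.functional as F
from torch.utils.data import TensorDataset

x = torch.randn(512, 3)
y = torch.randint(0, 10, (512,))
train_loader = DataLoader(TensorDataset(x[:400], y[:400]), batch_size=32, shuffle=True)
valid_loader = DataLoader(TensorDataset(x[400:], y[400:]), batch_size=32)

net = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 10))
optim = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optim, max_lr=0.1, epochs=3, steps_per_epoch=len(train_loader))

lrs = []  # filled by the step callback below
trainer = Trainer(net, optim, F.cross_entropy, scheduler=scheduler,
                  callbacks=[lambda: lrs.append(scheduler.get_last_lr()[0])])
trainer.run(epochs=3, train_loader=train_loader, valid_loader=valid_loader)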
The predict method is suited for inference over a single transformed mini-batch; calling the model on one large input tensor may cause an out-of-memory error. It runs under inference mode, so no computational graph is built, and the eval_context ensures the layers are in eval mode. For large inputs, one should use batch_predict, which does the same but consumes a data loader that applies the transforms.
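batch_predict is not shown in this section; a minimal sketch of what it might look like, written here as a standalone helper and assuming the loader yields (input, target) pairs, is:

@torch.inference_mode()
def batch_predict(trainer, data_loader):
    # sketch: run the model over each mini-batch in eval mode and concatenate the outputs
    with eval_context(trainer.model):
        preds = [trainer(x) for x, _ in data_loader]
    return torch.cat(preds, dim=0)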
To see eval_context in action, note that Dropout with p=1.0 zeroes every activation in train mode but acts as the identity in eval mode:

model = nn.Sequential(nn.Linear(3, 10), nn.Dropout(1.0))
trainer = Trainer(model, optim=None, scheduler=None, loss_fn=None)

# inference mode using eval_context
x = torch.ones(size=(1, 3), requires_grad=True)
print(f"__call__ {(trainer(x) > 0).float().mean():.3f}")
print(f"predict {(trainer.predict(x) > 0).float().mean():.3f}")
__call__ 0.000
predict 0.500
Checking computational graph generation:
y = trainer(x)
z = trainer.predict(x)
print("__call__ ", y.requires_grad)
print("predict ", z.requires_grad)
__call__ True
predict False