LR scheduling

We train the model with a one-cycle LR schedule [ST17]. The one-cycle policy anneals the learning rate from a base learning rate up to a set maximum learning rate, and then from that maximum down to a minimum learning rate much lower than the base learning rate. Momentum is also annealed inversely to the learning rate (it falls while the LR rises, and rises while the LR falls), which is necessary for stability.
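The two phases can be sketched in pure Python. This sketch assumes cosine annealing and borrows the default hyperparameters of PyTorch's `OneCycleLR` (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, `base_momentum=0.85`, `max_momentum=0.95`). The helper names `cos_anneal` and `one_cycle` are hypothetical, and the exact step boundaries differ slightly from PyTorch's implementation.

```python
import math


def cos_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


def one_cycle(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0,
              final_div_factor=1e4, base_mom=0.85, max_mom=0.95):
    """Hypothetical helper: (lr, momentum) at `step` of a two-phase one-cycle schedule."""
    initial_lr = max_lr / div_factor        # base LR at the start of training
    min_lr = initial_lr / final_div_factor  # final LR, far below the base LR
    warmup_end = pct_start * total_steps
    if step <= warmup_end:
        # Phase 1: LR rises to max_lr while momentum falls to base_mom.
        pct = step / warmup_end
        return cos_anneal(initial_lr, max_lr, pct), cos_anneal(max_mom, base_mom, pct)
    # Phase 2: LR decays to min_lr while momentum climbs back to max_mom.
    pct = (step - warmup_end) / (total_steps - warmup_end)
    return cos_anneal(max_lr, min_lr, pct), cos_anneal(base_mom, max_mom, pct)


for step in (0, 30, 100):
    lr, mom = one_cycle(step, total_steps=100, max_lr=0.01)
    print(f"step {step:3d}: lr={lr:.2e}  momentum={mom:.3f}")
```

With `max_lr=0.01` over 100 steps, the LR starts at 0.01 / 25 = 4e-4 with momentum 0.95. It peaks at 0.01 with momentum 0.85 around step 30, and it ends near 4e-8 with momentum back at 0.95.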

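import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR

# Assumed defined earlier in this notebook: mnist_model, DEVICE, Trainer, dl_train, dl_valid.

class SchedulerStatsCallback:
    """Records the LR and momentum (Adam's beta1) of the first param group."""

    def __init__(self, optim):
        self.lr = []
        self.momentum = []
        self.optim = optim

    def __call__(self):
        group = self.optim.param_groups[0]
        self.lr.append(group["lr"])
        # Adam-family optimizers keep momentum in betas[0]; OneCycleLR cycles it.
        self.momentum.append(group["betas"][0])

epochs = 3
model = mnist_model().to(DEVICE)
loss_fn = F.cross_entropy
optim = torch.optim.AdamW(model.parameters(), lr=0.001)
# OneCycleLR is meant to be stepped after every batch, so it needs the total step count.
scheduler = OneCycleLR(optim, max_lr=0.01, steps_per_epoch=len(dl_train), epochs=epochs)
scheduler_stats = SchedulerStatsCallback(optim)
trainer = Trainer(model, optim, loss_fn, scheduler, callbacks=[scheduler_stats])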
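To check what `OneCycleLR` does to the optimizer independently of the notebook's `Trainer`, the sketch below steps the scheduler on a single dummy parameter (an assumption standing in for the model) and records the LR and beta1 that the callback above would see. It shows that the starting LR is `max_lr / div_factor` (25 by default), so the `lr=0.001` passed to `AdamW` is overridden by the scheduler.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

# A single dummy parameter stands in for the model; no gradients are computed.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=0.001)
sched = OneCycleLR(opt, max_lr=0.01, total_steps=100)

lrs, beta1s = [], []
for _ in range(100):
    group = opt.param_groups[0]
    lrs.append(group["lr"])
    beta1s.append(group["betas"][0])
    opt.step()    # param.grad is None, so AdamW leaves the parameter unchanged
    sched.step()  # advance the schedule by one batch

peak = lrs.index(max(lrs))
print(f"first lr: {lrs[0]:.1e} (max_lr / div_factor, not the 0.001 given to AdamW)")
print(f"peak lr:  {lrs[peak]:.1e} at step {peak}, beta1 there: {beta1s[peak]:.2f}")
print(f"last lr:  {lrs[-1]:.1e}")
```

The step with the largest LR is also the step with the smallest beta1, which is the inverse annealing of momentum described above.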

Run training:

trainer.run(epochs=epochs, train_loader=dl_train, valid_loader=dl_valid)
[Epoch: 1/3]    loss: 1.0642  acc: 0.8922    val_loss: 0.6437  val_acc: 0.9356
[Epoch: 2/3]    loss: 0.3215  acc: 0.9518    val_loss: 0.1757  val_acc: 0.9728
[Epoch: 3/3]    loss: 0.1740  acc: 0.9676    val_loss: 0.1008  val_acc: 0.9794

Remark. After trying out other activations, SELU's performance is surprising: it also trains really fast. That self-normalizing property is no joke.
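One way to see what "self-normalizing" means, as a sketch under stated assumptions: push standard-normal inputs through a stack of random linear layers with LeCun-normal weights (std 1/sqrt(fan_in), the initialization SELU is designed for) and compare activation statistics against ReLU with the same weights. The width, depth and seed are arbitrary illustrative choices, and this does not use the notebook's `mnist_model`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
width, depth = 512, 16
x_selu = torch.randn(1024, width)
x_relu = x_selu.clone()
for _ in range(depth):
    # LeCun-normal init: zero mean, std = 1 / sqrt(fan_in).
    w = torch.randn(width, width) / width ** 0.5
    x_selu = F.selu(x_selu @ w)
    x_relu = F.relu(x_relu @ w)

print(f"SELU after {depth} layers: mean={x_selu.mean().item():.3f} std={x_selu.std().item():.3f}")
print(f"ReLU after {depth} layers: mean={x_relu.mean().item():.3f} std={x_relu.std().item():.3f}")
```

With SELU the activations should stay close to mean 0 and standard deviation 1 at every depth, while with ReLU under this initialization the scale shrinks by roughly a factor of sqrt(2) per layer.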

import matplotlib.pyplot as plt

# Top panel: per-step train loss (left axis) and train accuracy (right axis).
# Bottom panel: per-step learning rate (left axis) and Adam's beta1 (right axis).
fig, (ax1, ax3) = plt.subplots(2, 1, figsize=(8, 4), gridspec_kw={"height_ratios": [4, 1]})

AX1_COLOR = "C0"
AX2_COLOR = "C1"
# Thin translucent lines are raw per-step values; thick lines are the *_avg series.
ax1.plot(trainer.train_log["loss"], color=AX1_COLOR, linewidth=1, alpha=0.6)
ax1.plot(trainer.train_log["loss_avg"], color=AX1_COLOR, linewidth=2)
ax1.grid(axis="both", linestyle="dotted", alpha=0.8)
ax1.set_xlabel("step")

ax2 = ax1.twinx()  # second y-axis sharing the step axis
ax2.plot(trainer.train_log["accs"], color=AX2_COLOR, linewidth=1, alpha=0.6)
ax2.plot(trainer.train_log["accs_avg"], color=AX2_COLOR, linewidth=2)

ax1.set_ylabel("Train loss")
ax2.set_ylabel("Train accuracy")
ax1.yaxis.label.set_color(AX1_COLOR)
ax2.yaxis.label.set_color(AX2_COLOR)

ax4 = ax3.twinx()
ax3.plot(scheduler_stats.lr, color="black", label="lr", linewidth=2)
ax4.plot(scheduler_stats.momentum, color="blue", label=r"$\beta_1$", linewidth=2)
ax4.grid(linestyle="dotted")
ax3.set_xlabel("step")
ax3.set_ylabel("lr")
ax4.set_ylabel(r"$\beta_1$")
ax3.yaxis.label.set_color("black")
ax4.yaxis.label.set_color("blue")

fig.tight_layout()

Figure. Note the peak in train loss as the LR rises to the max_lr set at initialization, and the decreasing noise as the LR falls toward the end of training. This behaves similarly to our previous LR finder, which is a parameter-free method for finding a good base LR.

The bump in learning rate spans a large portion of training, which helps the optimizer avoid many sharp minima. This allows the network to train in fewer epochs. Note that increasing the number of epochs lengthens the exploration phase, not just the convergence phase.
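Because `OneCycleLR` spreads its phases over the total number of steps, the high-LR stretch grows in proportion to the number of epochs. The sketch below counts steps with LR at or above half of `max_lr` for 3 and 6 epochs. The helper `high_lr_steps`, the 100 steps per epoch, and the one-half threshold are illustrative assumptions, and a dummy parameter stands in for the model.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR


def high_lr_steps(epochs, steps_per_epoch=100, max_lr=0.01, frac=0.5):
    """Hypothetical helper: count steps whose LR is at least frac * max_lr."""
    param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, no real model
    opt = torch.optim.AdamW([param], lr=max_lr)
    sched = OneCycleLR(opt, max_lr=max_lr, steps_per_epoch=steps_per_epoch, epochs=epochs)
    count = 0
    for _ in range(epochs * steps_per_epoch):
        if opt.param_groups[0]["lr"] >= frac * max_lr:
            count += 1
        opt.step()    # no gradients, so nothing is updated
        sched.step()  # one scheduler step per batch
    return count


counts = {epochs: high_lr_steps(epochs) for epochs in (3, 6)}
for epochs, n in counts.items():
    print(f"{epochs} epochs: {n} steps at or above half of max_lr")
```

Doubling the epochs roughly doubles the number of high-LR steps, so a longer run spends proportionally more time exploring rather than only converging.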