{ "cells": [ { "cell_type": "markdown", "id": "82d152bf", "metadata": { "papermill": { "duration": 0.007271, "end_time": "2024-11-27T10:40:39.115477", "exception": false, "start_time": "2024-11-27T10:40:39.108206", "status": "completed" }, "tags": [] }, "source": [ "# LR scheduling" ] }, { "cell_type": "code", "execution_count": 1, "id": "e84b0f26", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:39.130563Z", "iopub.status.busy": "2024-11-27T10:40:39.130399Z", "iopub.status.idle": "2024-11-27T10:40:42.047581Z", "shell.execute_reply": "2024-11-27T10:40:42.047231Z" }, "papermill": { "duration": 2.927674, "end_time": "2024-11-27T10:40:42.050145", "exception": false, "start_time": "2024-11-27T10:40:39.122471", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "markdown", "id": "3bddf1e0", "metadata": { "papermill": { "duration": 0.001685, "end_time": "2024-11-27T10:40:42.054569", "exception": false, "start_time": "2024-11-27T10:40:42.052884", "status": "completed" }, "tags": [] }, "source": [ "Training the model with the [one-cycle LR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html) schedule {cite}`super-convergence-resnet`. The one-cycle policy anneals the learning rate from a base learning rate up to a set maximum learning rate, and then, from that maximum learning rate, down to a minimum learning rate much lower than the base learning rate. Momentum is also annealed inversely to the learning rate, which is necessary for stability." 
] }, { "cell_type": "code", "execution_count": 2, "id": "86b1365d", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:42.059148Z", "iopub.status.busy": "2024-11-27T10:40:42.058862Z", "iopub.status.idle": "2024-11-27T10:40:42.125789Z", "shell.execute_reply": "2024-11-27T10:40:42.125473Z" }, "papermill": { "duration": 0.07076, "end_time": "2024-11-27T10:40:42.127128", "exception": false, "start_time": "2024-11-27T10:40:42.056368", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
from torch.optim.lr_scheduler import OneCycleLR\n",
"\n",
"class SchedulerStatsCallback:\n",
" # Callback that snapshots the optimizer's hyperparameters each time it is\n",
" # invoked, so the LR / momentum trajectory of a schedule can be plotted\n",
" # after training.\n",
" def __init__(self, optim):\n",
" # Histories grow by one entry per __call__, in invocation order.\n",
" self.lr = []\n",
" self.momentum = []\n",
" self.optim = optim\n",
"\n",
" def __call__(self):\n",
" # Read current values from the first param group. For Adam-family\n",
" # optimizers (AdamW is used below), momentum is the first beta\n",
" # coefficient.\n",
" self.lr.append(self.optim.param_groups[0]["lr"])\n",
" self.momentum.append(self.optim.param_groups[0]["betas"][0])\n",
"\n",
"epochs = 3\n",
"model = mnist_model().to(DEVICE)\n",
"loss_fn = F.cross_entropy\n",
"optim = torch.optim.AdamW(model.parameters(), lr=0.001)\n",
"scheduler = OneCycleLR(optim, max_lr=0.01, steps_per_epoch=len(dl_train), epochs=epochs)\n",
"scheduler_stats = SchedulerStatsCallback(optim)\n",
"trainer = Trainer(model, optim, loss_fn, scheduler, callbacks=[scheduler_stats])\n",
"