Transfer learning and fine-tuning

Transfer learning is a common technique for leveraging large models trained on a related task (the so-called pretrained model). Here we use ResNet [HZRS15] trained on ImageNet, a dataset of over 1M images in 1000 categories. To adapt the pretrained model to our task, we keep only its feature extractor and train a new classification head (Fig. 48).

To avoid nullifying the pretrained weights with large random gradients, we first train the classification head to convergence while keeping the weights of the pretrained model fixed. Then we proceed with fine-tuning, where we train the entire model with a very low learning rate, so that the pretrained weights change only gradually.

Fig. 48 Training a new classifier over the same convolutional base. Source: Fig. 8.12 of [Cho21].

import torchinfo
from torchvision import models

# Note: newer torchvision versions use weights=models.ResNet18_Weights.IMAGENET1K_V1
# instead of the deprecated pretrained=True.
resnet = models.resnet18(pretrained=True)

BATCH_SIZE = 16
torchinfo.summary(resnet, input_size=(BATCH_SIZE, 3, 49, 49))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [16, 1000]                --
├─Conv2d: 1-1                            [16, 64, 25, 25]          9,408
├─BatchNorm2d: 1-2                       [16, 64, 25, 25]          128
├─ReLU: 1-3                              [16, 64, 25, 25]          --
├─MaxPool2d: 1-4                         [16, 64, 13, 13]          --
├─Sequential: 1-5                        [16, 64, 13, 13]          --
│    └─BasicBlock: 2-1                   [16, 64, 13, 13]          --
│    │    └─Conv2d: 3-1                  [16, 64, 13, 13]          36,864
│    │    └─BatchNorm2d: 3-2             [16, 64, 13, 13]          128
│    │    └─ReLU: 3-3                    [16, 64, 13, 13]          --
│    │    └─Conv2d: 3-4                  [16, 64, 13, 13]          36,864
│    │    └─BatchNorm2d: 3-5             [16, 64, 13, 13]          128
│    │    └─ReLU: 3-6                    [16, 64, 13, 13]          --
│    └─BasicBlock: 2-2                   [16, 64, 13, 13]          --
│    │    └─Conv2d: 3-7                  [16, 64, 13, 13]          36,864
│    │    └─BatchNorm2d: 3-8             [16, 64, 13, 13]          128
│    │    └─ReLU: 3-9                    [16, 64, 13, 13]          --
│    │    └─Conv2d: 3-10                 [16, 64, 13, 13]          36,864
│    │    └─BatchNorm2d: 3-11            [16, 64, 13, 13]          128
│    │    └─ReLU: 3-12                   [16, 64, 13, 13]          --
├─Sequential: 1-6                        [16, 128, 7, 7]           --
│    └─BasicBlock: 2-3                   [16, 128, 7, 7]           --
│    │    └─Conv2d: 3-13                 [16, 128, 7, 7]           73,728
│    │    └─BatchNorm2d: 3-14            [16, 128, 7, 7]           256
│    │    └─ReLU: 3-15                   [16, 128, 7, 7]           --
│    │    └─Conv2d: 3-16                 [16, 128, 7, 7]           147,456
│    │    └─BatchNorm2d: 3-17            [16, 128, 7, 7]           256
│    │    └─Sequential: 3-18             [16, 128, 7, 7]           8,448
│    │    └─ReLU: 3-19                   [16, 128, 7, 7]           --
│    └─BasicBlock: 2-4                   [16, 128, 7, 7]           --
│    │    └─Conv2d: 3-20                 [16, 128, 7, 7]           147,456
│    │    └─BatchNorm2d: 3-21            [16, 128, 7, 7]           256
│    │    └─ReLU: 3-22                   [16, 128, 7, 7]           --
│    │    └─Conv2d: 3-23                 [16, 128, 7, 7]           147,456
│    │    └─BatchNorm2d: 3-24            [16, 128, 7, 7]           256
│    │    └─ReLU: 3-25                   [16, 128, 7, 7]           --
├─Sequential: 1-7                        [16, 256, 4, 4]           --
│    └─BasicBlock: 2-5                   [16, 256, 4, 4]           --
│    │    └─Conv2d: 3-26                 [16, 256, 4, 4]           294,912
│    │    └─BatchNorm2d: 3-27            [16, 256, 4, 4]           512
│    │    └─ReLU: 3-28                   [16, 256, 4, 4]           --
│    │    └─Conv2d: 3-29                 [16, 256, 4, 4]           589,824
│    │    └─BatchNorm2d: 3-30            [16, 256, 4, 4]           512
│    │    └─Sequential: 3-31             [16, 256, 4, 4]           33,280
│    │    └─ReLU: 3-32                   [16, 256, 4, 4]           --
│    └─BasicBlock: 2-6                   [16, 256, 4, 4]           --
│    │    └─Conv2d: 3-33                 [16, 256, 4, 4]           589,824
│    │    └─BatchNorm2d: 3-34            [16, 256, 4, 4]           512
│    │    └─ReLU: 3-35                   [16, 256, 4, 4]           --
│    │    └─Conv2d: 3-36                 [16, 256, 4, 4]           589,824
│    │    └─BatchNorm2d: 3-37            [16, 256, 4, 4]           512
│    │    └─ReLU: 3-38                   [16, 256, 4, 4]           --
├─Sequential: 1-8                        [16, 512, 2, 2]           --
│    └─BasicBlock: 2-7                   [16, 512, 2, 2]           --
│    │    └─Conv2d: 3-39                 [16, 512, 2, 2]           1,179,648
│    │    └─BatchNorm2d: 3-40            [16, 512, 2, 2]           1,024
│    │    └─ReLU: 3-41                   [16, 512, 2, 2]           --
│    │    └─Conv2d: 3-42                 [16, 512, 2, 2]           2,359,296
│    │    └─BatchNorm2d: 3-43            [16, 512, 2, 2]           1,024
│    │    └─Sequential: 3-44             [16, 512, 2, 2]           132,096
│    │    └─ReLU: 3-45                   [16, 512, 2, 2]           --
│    └─BasicBlock: 2-8                   [16, 512, 2, 2]           --
│    │    └─Conv2d: 3-46                 [16, 512, 2, 2]           2,359,296
│    │    └─BatchNorm2d: 3-47            [16, 512, 2, 2]           1,024
│    │    └─ReLU: 3-48                   [16, 512, 2, 2]           --
│    │    └─Conv2d: 3-49                 [16, 512, 2, 2]           2,359,296
│    │    └─BatchNorm2d: 3-50            [16, 512, 2, 2]           1,024
│    │    └─ReLU: 3-51                   [16, 512, 2, 2]           --
├─AdaptiveAvgPool2d: 1-9                 [16, 512, 1, 1]           --
├─Linear: 1-10                           [16, 1000]                513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.99
==========================================================================================
Input size (MB): 0.46
Forward/backward pass size (MB): 37.34
Params size (MB): 46.76
Estimated Total Size (MB): 84.56
==========================================================================================
Next, we drop ResNet's final pooling and fully connected layers and attach a new classification head for our two-class task:

import torch
from torch import nn

in_features = resnet.fc.in_features
num_hidden = 256

head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.BatchNorm1d(in_features),
    nn.Dropout(0.5),
    nn.Linear(in_features, num_hidden),
    nn.ReLU(),
    nn.BatchNorm1d(num_hidden),
    nn.Dropout(0.5),
    nn.Linear(num_hidden, 2),
)

model = nn.Sequential(
    nn.Sequential(*list(resnet.children())[:-2]),
    head
)

Remark. BatchNorm [IS15] in the network head helps stabilize activations and gradients. Dropout is also used to regularize the dense layers.
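
As a quick sanity check (a sketch, using the model and BATCH_SIZE defined above), the assembled network should map a batch of images to two logits:

x = torch.randn(BATCH_SIZE, 3, 49, 49)  # dummy batch, same shape as in the summary
model.eval()                            # eval mode: BatchNorm uses running statistics
with torch.no_grad():
    print(model(x).shape)               # torch.Size([16, 2])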

Static features

Freezing the feature extraction layers:

for param in model[0].parameters(): # model[0] = pretrained
    param.requires_grad = False
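
To confirm the freeze took effect, we can compare trainable and total parameter counts (a quick check; the exact numbers follow from the head defined above):

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable: {n_trainable:,} / {n_total:,}")  # only the head: ~133K of ~11.3M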

Setting up the data loaders:

from torch.utils.data import DataLoader, Subset

dl_train = DataLoader(Subset(ds_train, torch.arange(32000)), batch_size=BATCH_SIZE, shuffle=True)
dl_valid = DataLoader(Subset(ds_valid, torch.arange(8000)), batch_size=BATCH_SIZE, shuffle=False)
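
One batch from the training loader should match the input shape used in the summary above (a quick check, assuming the datasets yield (image, label) pairs as set up earlier):

x, y = next(iter(dl_train))
print(x.shape, y.shape)  # e.g. torch.Size([16, 3, 49, 49]) torch.Size([16])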

Training the model using AdamW [LH17] with a base learning rate of 0.001 (the one-cycle schedule ramps this up to a peak of 0.01):

import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR

epochs = 10
optim = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = OneCycleLR(optim, max_lr=0.01, steps_per_epoch=len(dl_train), epochs=epochs)
trainer = Trainer(model, optim, loss_fn=F.cross_entropy, scheduler=scheduler, device=DEVICE)  # Trainer, DEVICE defined in earlier sections
trainer.run(epochs=epochs, train_loader=dl_train, valid_loader=dl_valid)
[Epoch: 01/10]    loss: 0.5991  acc: 0.7019    val_loss: 0.5274  val_acc: 0.7422
[Epoch: 02/10]    loss: 0.5862  acc: 0.6944    val_loss: 0.5135  val_acc: 0.7561
[Epoch: 03/10]    loss: 0.5926  acc: 0.6794    val_loss: 0.5142  val_acc: 0.7522
[Epoch: 04/10]    loss: 0.5714  acc: 0.6956    val_loss: 0.5250  val_acc: 0.7376
[Epoch: 05/10]    loss: 0.5760  acc: 0.7069    val_loss: 0.5082  val_acc: 0.7558
[Epoch: 06/10]    loss: 0.5508  acc: 0.7188    val_loss: 0.5039  val_acc: 0.7611
[Epoch: 07/10]    loss: 0.5172  acc: 0.7400    val_loss: 0.4862  val_acc: 0.7691
[Epoch: 08/10]    loss: 0.5345  acc: 0.7362    val_loss: 0.4896  val_acc: 0.7791
[Epoch: 09/10]    loss: 0.5257  acc: 0.7456    val_loss: 0.4851  val_acc: 0.7749
[Epoch: 10/10]    loss: 0.5327  acc: 0.7225    val_loss: 0.4857  val_acc: 0.7744
Plotting the training history:
import matplotlib.pyplot as plt

def plot_training_history(trainer):
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    
    num_epochs = len(trainer.valid_log["accs"])
    num_steps_per_epoch = len(trainer.train_log["accs"]) // num_epochs
    ax[0].plot(trainer.train_log["loss"], alpha=0.3, color="C0")
    ax[1].plot(trainer.train_log["accs"], alpha=0.3, color="C0")
    ax[0].plot(trainer.train_log["loss_avg"], label="train", color="C0")
    ax[1].plot(trainer.train_log["accs_avg"], label="train", color="C0")
    ax[0].plot(list(range(num_steps_per_epoch, (num_epochs + 1) * num_steps_per_epoch, num_steps_per_epoch)), trainer.valid_log["loss"], label="valid", color="C1")
    ax[1].plot(list(range(num_steps_per_epoch, (num_epochs + 1) * num_steps_per_epoch, num_steps_per_epoch)), trainer.valid_log["accs"], label="valid", color="C1")
    ax[0].set_xlabel("step")
    ax[0].set_ylabel("loss")
    ax[0].grid(linestyle="dashed", alpha=0.3)
    ax[1].set_xlabel("step")
    ax[1].set_ylabel("accuracy")
    ax[1].grid(linestyle="dashed", alpha=0.3)
    ax[1].legend()
    ax[0].set_ylim(0, max(trainer.train_log["loss"]))
    ax[1].set_ylim(0, 1)
    ax[0].ticklabel_format(axis="x", style="sci", scilimits=(3, 3))
    ax[1].ticklabel_format(axis="x", style="sci", scilimits=(3, 3))
    
    fig.tight_layout();
plot_training_history(trainer)
Fig. Training history of the classification head: per-step train loss and accuracy (running average) with per-epoch validation metrics.

Remark. The validation step accumulates results over the entire validation set after each epoch, for a fixed set of weights. This simulates inference performance if we load the trained model at that checkpoint. On the other hand, evaluating train metrics over the full training set at each epoch would be expensive since the dataset is large. Instead, train metrics are accumulated at each step as a running average over the previous steps.
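
Such a running average can be maintained incrementally, without storing all previous step values. A minimal sketch (the actual Trainer bookkeeping may differ):

losses = [0.9, 0.7, 0.6, 0.5]  # hypothetical per-step losses
avg = 0.0
for step, value in enumerate(losses, start=1):
    avg += (value - avg) / step  # incremental mean over steps 1..step
    print(f"step {step}: loss={value:.2f}  running_avg={avg:.3f}")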


Fine-tuning

Unfreezing the pretrained model layers. Note that we use much smaller learning rates (for both the optimizer and the scheduler):

for param in model[0].parameters():
    param.requires_grad = True

# 100x smaller lr (both optim and scheduler)
epochs = 20
optim = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.05)
scheduler = OneCycleLR(optim, max_lr=0.0001, steps_per_epoch=len(dl_train), epochs=epochs)
trainer_ft = Trainer(model, optim, loss_fn=F.cross_entropy, scheduler=scheduler, device=DEVICE)
trainer_ft.run(epochs=epochs, train_loader=dl_train, valid_loader=dl_valid)
[Epoch: 01/20]    loss: 0.4978  acc: 0.7706    val_loss: 0.4265  val_acc: 0.8160
[Epoch: 02/20]    loss: 0.4262  acc: 0.8113    val_loss: 0.3741  val_acc: 0.8434
[Epoch: 03/20]    loss: 0.4110  acc: 0.8175    val_loss: 0.3544  val_acc: 0.8540
[Epoch: 04/20]    loss: 0.3923  acc: 0.8331    val_loss: 0.3363  val_acc: 0.8626
[Epoch: 05/20]    loss: 0.3599  acc: 0.8469    val_loss: 0.3388  val_acc: 0.8596
[Epoch: 06/20]    loss: 0.3406  acc: 0.8581    val_loss: 0.3171  val_acc: 0.8699
[Epoch: 07/20]    loss: 0.3875  acc: 0.8469    val_loss: 0.3338  val_acc: 0.8685
[Epoch: 08/20]    loss: 0.3433  acc: 0.8650    val_loss: 0.3190  val_acc: 0.8856
[Epoch: 09/20]    loss: 0.3260  acc: 0.8712    val_loss: 0.3132  val_acc: 0.8698
[Epoch: 10/20]    loss: 0.3074  acc: 0.8831    val_loss: 0.2931  val_acc: 0.8825
[Epoch: 11/20]    loss: 0.3176  acc: 0.8806    val_loss: 0.2966  val_acc: 0.8868
[Epoch: 12/20]    loss: 0.3048  acc: 0.8775    val_loss: 0.2904  val_acc: 0.8848
[Epoch: 13/20]    loss: 0.2680  acc: 0.8894    val_loss: 0.3251  val_acc: 0.8842
[Epoch: 14/20]    loss: 0.2866  acc: 0.8825    val_loss: 0.2846  val_acc: 0.8896
[Epoch: 15/20]    loss: 0.2699  acc: 0.8919    val_loss: 0.2710  val_acc: 0.8958
[Epoch: 16/20]    loss: 0.2517  acc: 0.9062    val_loss: 0.2798  val_acc: 0.8948
[Epoch: 17/20]    loss: 0.2405  acc: 0.9100    val_loss: 0.2753  val_acc: 0.8981
[Epoch: 18/20]    loss: 0.2378  acc: 0.9031    val_loss: 0.2645  val_acc: 0.8969
[Epoch: 19/20]    loss: 0.2349  acc: 0.9144    val_loss: 0.2649  val_acc: 0.8990
[Epoch: 20/20]    loss: 0.2411  acc: 0.9163    val_loss: 0.2642  val_acc: 0.8986
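
A common variant at this stage (not used above) is discriminative learning rates: pass parameter groups to the optimizer so that the pretrained trunk gets a smaller learning rate than the head. A sketch with hypothetical values:

optim = torch.optim.AdamW([
    {"params": model[0].parameters(), "lr": 1e-5},  # pretrained trunk: small lr
    {"params": model[1].parameters(), "lr": 1e-4},  # new head: larger lr
], weight_decay=0.05)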
Plotting the combined training history over both stages:
loss = trainer.train_log["loss"] + trainer_ft.train_log["loss"]
accs = trainer.train_log["accs"] + trainer_ft.train_log["accs"]
loss_avg = trainer.train_log["loss_avg"] + trainer_ft.train_log["loss_avg"]
accs_avg = trainer.train_log["accs_avg"] + trainer_ft.train_log["accs_avg"]
val_loss = trainer.valid_log["loss"] + trainer_ft.valid_log["loss"]
val_accs = trainer.valid_log["accs"] + trainer_ft.valid_log["accs"]
num_epochs = len(val_loss)
num_steps_per_epoch = len(loss) // num_epochs

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].plot(loss, alpha=0.3, color="C0")
ax[1].plot(accs, alpha=0.3, color="C0")
ax[0].plot(loss_avg, label="train", color="C0")
ax[1].plot(accs_avg, label="train", color="C0")
ax[0].plot(list(range(num_steps_per_epoch, (num_epochs + 1) * num_steps_per_epoch, num_steps_per_epoch)), val_loss, label="valid", color="C1")
ax[1].plot(list(range(num_steps_per_epoch, (num_epochs + 1) * num_steps_per_epoch, num_steps_per_epoch)), val_accs, label="valid", color="C1")
ax[0].axvline(len(trainer.train_log["loss"]), color="black", linestyle="dotted", label="[fine-tuning]")
ax[1].axvline(len(trainer.train_log["loss"]), color="black", linestyle="dotted", label="[fine-tuning]")
ax[0].set_xlabel("step")
ax[0].set_ylabel("loss")
ax[0].grid(linestyle="dashed", alpha=0.3)
ax[0].set_ylim(0, max(loss))
ax[1].set_xlabel("step")
ax[1].set_ylabel("accuracy")
ax[1].grid(linestyle="dashed", alpha=0.3)
ax[1].set_ylim(0, 1)
ax[1].legend(framealpha=1.0)
ax[0].ticklabel_format(axis="x", style="sci", scilimits=(3, 3))
ax[1].ticklabel_format(axis="x", style="sci", scilimits=(3, 3))
fig.tight_layout();
Fig. The classification head is first trained on the outputs of the pretrained ResNet with fixed weights. Once the head has converged, the pretrained weights are unfrozen and trained with a small learning rate. Performance improves at a faster rate during fine-tuning, but also fluctuates more at this stage.


Remarks. When data augmentation is turned off, the model overfits and the validation curves diverge very early in training (i.e. the model memorizes the training data). Data augmentation prevents this by adding noise to the inputs. The dense layers also do not train well without BN. Finally, since the data is slightly imbalanced, we should also look at the confusion matrix and PR curve (see the sketch after the inference results below).

Inference

Note that calling the model on an input x with a large batch dimension may cause out-of-memory errors for large models. Hence, using a data loader for offline batch inference is good practice:

class InputDataLoader:
    """Wraps a labeled data loader so that iteration yields only the inputs."""

    def __init__(self, data_loader):
        self.data_loader = data_loader

    def __iter__(self):
        for batch in self.data_loader:
            yield batch[0]  # batch = (inputs, targets)


@torch.inference_mode()
def batch_predict(trainer: Trainer, input_loader) -> torch.Tensor:
    with eval_context(trainer.model):
        preds = [trainer(x) for x in input_loader]
        preds = torch.cat(preds, dim=0)
    return preds


pred = batch_predict(trainer, InputDataLoader(dl_valid))
print(pred.shape)
print(pred)
torch.Size([8000, 2])
tensor([[ 0.0836, -0.0971],
        [ 1.1856, -1.1767],
        [ 0.8685, -0.8674],
        ...,
        [ 1.0003, -0.9968],
        [-2.4745,  2.4580],
        [-2.5150,  2.4989]], device='mps:0')

The accuracy computed from these predictions should equal the final validation accuracy:

y = torch.cat([batch[1] for batch in dl_valid], dim=0)
print((pred.argmax(dim=1) == y.to(DEVICE)).float().mean().item())
print(trainer_ft.evaluate(dl_valid)["accs"])
print(trainer_ft.valid_log["accs"][-1]) # or look at final valid log
0.8986250162124634
0.898625
0.898625
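
As noted in the remarks above, the data is slightly imbalanced, so accuracy alone can be misleading. With pred and y in hand, a sketch of the confusion matrix and PR curve using scikit-learn (assuming it is installed):

from sklearn.metrics import confusion_matrix, precision_recall_curve

y_true = y.numpy()
y_pred = pred.argmax(dim=1).cpu().numpy()
y_prob = F.softmax(pred, dim=1)[:, 1].cpu().numpy()  # score for the positive class

print(confusion_matrix(y_true, y_pred))
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)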

Note that inputs from our data loaders come already transformed. To process raw images, we have to apply the same inference transforms before calling the model:

import cv2

file_path = "data/histopathologic-cancer-detection/test/0a0a1f3867f41e02353afcaf503f63be1bdd35ec.tif"
test_data = cv2.imread(file_path)
print(trainer.predict(transform_infer(test_data).unsqueeze(0)))
tensor([[ 1.4946, -1.4783]], device='mps:0')
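
The outputs are raw logits; applying a softmax converts them to class probabilities (a quick follow-up, reusing the prediction above):

logits = trainer.predict(transform_infer(test_data).unsqueeze(0))
print(F.softmax(logits, dim=1))  # ≈ tensor([[0.9513, 0.0487]]) for the logits above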