Model

Combining convolution and pooling layers into a deep model:

import torch
from torch import nn
import torchsummary

mnist_model = lambda: nn.Sequential(
    nn.Conv2d(1, 32, 3, 1, 1),    # 1x28x28 -> 32x28x28 (3x3 kernel, padding 1)
    nn.SELU(),
    nn.MaxPool2d(2, 2),           # -> 32x14x14

    nn.Conv2d(32, 32, 5, 1, 0),   # -> 32x10x10 (5x5 kernel, no padding)
    nn.SELU(),
    nn.MaxPool2d(2, 2),           # -> 32x5x5

    nn.Flatten(),                 # -> 800 = 32*5*5
    nn.Linear(800, 256), nn.SELU(), nn.Dropout(0.5),
    nn.Linear(256, 10)            # -> 10 class logits
)
torchsummary.summary(mnist_model(), (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 28, 28]             320
              SELU-2           [-1, 32, 28, 28]               0
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 32, 10, 10]          25,632
              SELU-5           [-1, 32, 10, 10]               0
         MaxPool2d-6             [-1, 32, 5, 5]               0
           Flatten-7                  [-1, 800]               0
            Linear-8                  [-1, 256]         205,056
              SELU-9                  [-1, 256]               0
          Dropout-10                  [-1, 256]               0
           Linear-11                   [-1, 10]           2,570
================================================================
Total params: 233,578
Trainable params: 233,578
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.50
Params size (MB): 0.89
Estimated Total Size (MB): 1.39
----------------------------------------------------------------
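As a sanity check on the shapes reported above, we can push a dummy batch through the network and print the output shape after each module (a minimal sketch; the batch size of 2 is arbitrary):

model = mnist_model()
x = torch.randn(2, 1, 28, 28)   # dummy batch: 2 grayscale 28x28 images
for layer in model:
    x = layer(x)
    print(f"{layer.__class__.__name__:>10}: {tuple(x.shape)}")
# Final shape is (2, 10): one logit per class.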

Remark. We use the SELU activation [KUMH17] for fun. Note that we also use Dropout [SHK+14] as regularization for the dense layers. Observe that the number of parameters due to convolutions is small relative to the final network size (only about 11% of the total)!
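We can verify this figure directly by summing parameter counts over the modules (a quick sketch):

model = mnist_model()
conv = sum(p.numel() for m in model if isinstance(m, nn.Conv2d) for p in m.parameters())
total = sum(p.numel() for p in model.parameters())
print(f"conv params: {conv} / {total} = {conv / total:.1%}")
# conv params: 25952 / 233578 = 11.1%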


Setting up MNIST data loaders:

from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader

# ToTensor() already scales pixel values to [0, 1], so no extra division by 255 is needed
transform = transforms.ToTensor()

g = torch.Generator().manual_seed(RANDOM_SEED)
ds = MNIST(root=DATASET_DIR, download=False, transform=transform)  # assumes MNIST files already in DATASET_DIR
ds_train, ds_valid = random_split(ds, [55000, 5000], generator=g)  # 55k train / 5k validation
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)   # (!) see remark below
dl_valid = DataLoader(ds_valid, batch_size=32, shuffle=False)
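As a quick sanity check on the transform (a sketch that assumes the MNIST files are already present at DATASET_DIR), the transformed samples should lie in [0, 1]:

x, y = ds[0]
print(x.shape, x.min().item(), x.max().item())  # torch.Size([1, 28, 28]), values in [0.0, 1.0]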

Remark. shuffle=True is important for SGD training: the model reaches a noticeably lower validation score when it loops through the samples in the same fixed order every epoch. This may be due to cyclic behavior in the updates (e.g. successive updates partially cancelling out).
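To see the loaders in action, here is a minimal training-loop sketch; the optimizer choice, learning rate, and single epoch are illustrative assumptions, not prescribed above:

import torch.nn.functional as F

model = mnist_model()
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

model.train()
for x, y in dl_train:            # shuffled mini-batches
    opt.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    opt.step()

model.eval()                     # disables Dropout for evaluation
correct = 0
with torch.no_grad():
    for x, y in dl_valid:
        correct += (model(x).argmax(dim=1) == y).sum().item()
print(f"valid acc: {correct / len(ds_valid):.3f}")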