Model
Combining convolution and pooling layers into a deep model:
import torch.nn as nn
import torchsummary

mnist_model = lambda: nn.Sequential(
    nn.Conv2d(1, 32, 3, 1, 1),    # 1 -> 32 channels, 3x3 kernel, stride 1, padding 1 (keeps 28x28)
    nn.SELU(),
    nn.MaxPool2d(2, 2),           # 28x28 -> 14x14
    nn.Conv2d(32, 32, 5, 1, 0),   # 32 -> 32 channels, 5x5 kernel, no padding (14x14 -> 10x10)
    nn.SELU(),
    nn.MaxPool2d(2, 2),           # 10x10 -> 5x5
    nn.Flatten(),                 # 32 * 5 * 5 = 800 features
    nn.Linear(800, 256), nn.SELU(), nn.Dropout(0.5),
    nn.Linear(256, 10)            # one logit per digit class
)

torchsummary.summary(mnist_model(), (1, 28, 28))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 28, 28]             320
              SELU-2           [-1, 32, 28, 28]               0
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 32, 10, 10]          25,632
              SELU-5           [-1, 32, 10, 10]               0
         MaxPool2d-6             [-1, 32, 5, 5]               0
           Flatten-7                  [-1, 800]               0
            Linear-8                  [-1, 256]         205,056
              SELU-9                  [-1, 256]               0
          Dropout-10                  [-1, 256]               0
           Linear-11                   [-1, 10]           2,570
================================================================
Total params: 233,578
Trainable params: 233,578
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.50
Params size (MB): 0.89
Estimated Total Size (MB): 1.39
----------------------------------------------------------------
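To see where the 800 input features of the first dense layer come from: the 3x3 convolution with padding 1 preserves the 28x28 resolution, pooling halves it to 14x14, the 5x5 convolution without padding shrinks it to 10x10, and the second pooling gives 5x5, so flattening the 32 channels yields 32 * 5 * 5 = 800. A quick sanity check (a sketch, assuming a PyTorch version where nn.Sequential supports slicing):
import torch
x = torch.zeros(1, 1, 28, 28)        # dummy batch with a single MNIST-sized image
features = mnist_model()[:7](x)      # apply layers up to and including Flatten
print(features.shape)                # torch.Size([1, 800])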
Remark. We use the SELU activation [KUMH17] for fun. Note that we also use Dropout [SHK+14] as regularization for the dense layers. Observe that the number of parameters due to the convolutions is small relative to the final network size (only about 11%)!
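The parameter share of the convolutions can be checked directly by counting parameters per module; a minimal sketch using standard nn.Module introspection:
model = mnist_model()
conv_params = sum(p.numel() for m in model if isinstance(m, nn.Conv2d)
                  for p in m.parameters())
total_params = sum(p.numel() for p in model.parameters())
print(conv_params, total_params, conv_params / total_params)   # 25952 233578 ~ 0.11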
Setting up MNIST data loaders:
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader

transform = transforms.Compose([
    # ToTensor() already converts PIL images to float tensors scaled to [0, 1],
    # so no further division by 255 is needed
    transforms.ToTensor(),
])

g = torch.Generator().manual_seed(RANDOM_SEED)
ds = MNIST(root=DATASET_DIR, download=False, transform=transform)
ds_train, ds_valid = random_split(ds, [55000, 5000], generator=g)   # 60k images -> 55k train / 5k valid
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)        # (!)
dl_valid = DataLoader(ds_valid, batch_size=32, shuffle=False)
Remark. shuffle=True is important for SGD training. The model will end up with a lower validation score if it loops through the samples in the same order during training. This may be due to cyclic behavior in the updates (e.g. successive updates cancelling each other out).
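To see the mechanism, compare the first batch across two epochs: with shuffle=False the DataLoader replays the samples in exactly the same order every epoch, while shuffle=True draws a fresh permutation each time its iterator is created. A small sketch (labels of the first eight samples; creating a new iterator stands in for starting a new epoch):
for shuffle in (False, True):
    dl = DataLoader(ds_train, batch_size=32, shuffle=shuffle)
    epoch_1 = next(iter(dl))[1][:8]   # labels of the first batch in "epoch" 1
    epoch_2 = next(iter(dl))[1][:8]   # labels of the first batch in "epoch" 2
    print(shuffle, epoch_1.tolist(), epoch_2.tolist())
# shuffle=False prints identical label lists; shuffle=True (almost surely) does not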