Training from scratch

Our task is to learn the following dataset, consisting of noisy measurements around a quadratic curve:

# IPython magic: render matplotlib figures as SVG in the notebook.
%config InlineBackend.figure_format = "svg"
import numpy as np
import matplotlib.pyplot as plt
# Fix NumPy's global RNG so the synthetic dataset is reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

N = 1500
# N evenly spaced inputs on [-1, 5].
X = np.linspace(-1, 5, N)
# Quadratic target plus Gaussian noise (effective std = 0.5 * 3 = 1.5).
Y = X ** 2 + 0.5 * np.random.normal(size=N, scale=3)

Helper for loading the samples:

import random
random.seed(RANDOM_SEED)

class DataLoader:
    def __init__(self, dataset, batch_size):
        """Wrap `dataset` ((x, y) float pairs) as Node pairs and serve random batches.

        NOTE: `load` draws each batch independently at random, with
        replacement across calls — it does not iterate a fixed partition
        of the dataset.
        """
        self.batch_size = batch_size
        self.dataset = [(Node(x), Node(y)) for x, y in dataset]

    def load(self):
        """Return one batch: a uniform random sample without replacement.

        The requested size is clamped to the dataset size so an oversized
        `batch_size` degrades gracefully instead of raising ValueError.
        """
        return random.sample(self.dataset, min(self.batch_size, len(self.dataset)))

    def __len__(self):
        """Number of samples in the underlying dataset."""
        return len(self.dataset)

The function optim_step implements one step of SGD. Since gradients accumulate across backward calls, mini-batch updates could be implemented by summing gradients over a batch before stepping; for simplicity, however, we perform one update per training instance. Here loss_fn computes the squared error between two nodes (the MSE of a single sample).

def optim_step(model, lr=1.0):
    """Apply one step of SGD: p <- p - lr * grad for every model parameter.

    Assumes `.grad` is already populated by a preceding backward pass;
    gradients are NOT cleared here (the caller zeroes them).
    """
    for param in model.parameters():
        param.data = param.data - lr * param.grad

def loss_fn(y_pred, y_true):
    """Squared error between prediction and target (a Node when inputs are Nodes)."""
    err = y_pred - y_true
    return err ** 2

Running the training algorithm:

from tqdm.notebook import tqdm

def train(model, dataset, steps, lr=0.1, batch_size=32):
    """Train `model` with single-sample SGD for `steps` random batches.

    Returns the list of per-sample loss values, one entry per update.
    """
    loader = DataLoader(dataset, batch_size)
    history = []
    for _ in tqdm(range(steps)):
        batch = loader.load()
        for x_node, y_node in batch:
            # Forward pass, backprop, then an immediate one-sample update.
            loss = loss_fn(model([x_node]), y_node)
            loss.backward()
            optim_step(model, lr=lr)
            model.zero_grad()
            history.append(loss.data)
    return history

# Pair up inputs and targets, build a small tanh MLP, and fit it with SGD.
dataset = list(zip(X, Y))
model = MLP(1, [8, 4, 1], "tanh")
losses = train(model, dataset, steps=5000, lr=0.003, batch_size=8)

The moving average of the loss curve is decreasing:

w = 50  # moving-average window, in number of updates
# Truncate to a multiple of w before reshaping: reshape(-1, w) raises
# ValueError whenever len(losses) is not evenly divisible by the window.
n_full = len(losses) // w * w
loss_avg = np.array(losses[:n_full]).reshape(-1, w).mean(axis=1)

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(loss_avg, label="train")
ax.grid(linestyle="dotted", alpha=0.6)
ax.set_ylabel("MSE")
ax.set_xlabel("steps")
ax.legend();
../../../_images/6d8d5e3f9245bdf1a937f289978c1a66363e62f5e11ad30e37fdba93d45d49f6.svg

Model learned: ヾ( ˃ᴗ˂ )◞ • *✰

Hide code cell source
# Scatter the raw samples and overlay the model's learned curve.
plt.figure(figsize=(6, 4))
plt.scatter(X, Y, label="data", s=2, alpha=0.7, color="black")
# Each scalar input is wrapped in a Node and passed as a 1-element list,
# matching the model's call convention; .data extracts the scalar output.
plt.plot(X, [model([Node(x)]).data for x in X], label="model", color="C1", linewidth=2)

plt.ylabel("y")
plt.xlabel("x")
plt.grid(linestyle="dotted")
plt.legend();
../../../_images/d3d5db443ce7187e6e33ab2ac61d755a64834a11f97abdec0ef879e07a0ecfed.svg

Model graph visualization:

# Forward a single input through the trained network, then render the
# resulting computation graph of the output node.
x = Node(1.0)
pred = model([x])
draw_graph(pred)
../../../_images/bc3f8ec46878b44f286419c2882d79eef356dcc5ad23eada3499f4bd904fcc00.svg