In this article we will explore the stochastic weight averaging (SWA) method. We will use the CoverType dataset. The goal is to predict the forest cover type from cartographic variables only (no remotely sensed data); there are seven forest cover types.

Independent variables were derived from data originally obtained from US Geological Survey (USGS) and USFS data. Data is in raw form (not scaled) and contains binary (0 or 1) columns of data for qualitative independent variables (wilderness areas and soil types). The labels represent the dominant species of trees found on a 30m × 30m forest cell, as determined from the US Forest Service (USFS) Region 2 Resource Information System (RIS) data. We will use the data as provided by the scikit-learn function fetch_covtype().

import matplotlib.pylab as plt
plt.style.use('ggplot')
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import fetch_covtype
# Load the CoverType dataset through scikit-learn.
data = fetch_covtype()

# Build one frame: the feature columns followed by the label column.
# Subtracting 1 shifts the labels so that they start at zero.
feature_frame = pd.DataFrame(data.data, columns=data.feature_names)
label_frame = pd.DataFrame(data.target - 1, columns=data.target_names)
df = pd.concat([feature_frame, label_frame], axis=1)

row_count, column_count = df.shape
print(f"Dataset contains {row_count:,} rows and {column_count:,} columns.")
df.head()
Dataset contains 581,012 rows and 55 columns.
Elevation Aspect Slope Horizontal_Distance_To_Hydrology Vertical_Distance_To_Hydrology Horizontal_Distance_To_Roadways Hillshade_9am Hillshade_Noon Hillshade_3pm Horizontal_Distance_To_Fire_Points ... Soil_Type_31 Soil_Type_32 Soil_Type_33 Soil_Type_34 Soil_Type_35 Soil_Type_36 Soil_Type_37 Soil_Type_38 Soil_Type_39 Cover_Type
0 2596.0 51.0 3.0 258.0 0.0 510.0 221.0 232.0 148.0 6279.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4
1 2590.0 56.0 2.0 212.0 -6.0 390.0 220.0 235.0 151.0 6225.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4
2 2804.0 139.0 9.0 268.0 65.0 3180.0 234.0 238.0 135.0 6121.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
3 2785.0 155.0 18.0 242.0 118.0 3090.0 238.0 238.0 122.0 6211.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
4 2595.0 45.0 2.0 153.0 -1.0 391.0 220.0 234.0 150.0 6172.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4

5 rows × 55 columns

The dataset is of medium size; all entries are valid.

# Verify that no column contains missing values.  `DataFrame.isnull` is an
# alias of `DataFrame.isna`, so one pass over the frame covers both.
missing_data = df.isna().sum()
assert (missing_data == 0).all(), (
    f"Columns with missing values: {missing_data[missing_data > 0].to_dict()}"
)

We start, as usual, with some exploratory data analysis. Most columns contain only zeros and ones; we begin with the ten that don’t and plot them.

# Name of the label column; the heatmap cells further down group by it too.
target = 'Cover_Type'

# Keep the numeric features that take values other than 0/1 (this drops the
# binary indicator columns) and leave out the label itself.
numeric_cols = [
    col
    for col in df.select_dtypes(include=[np.number]).columns
    if col != target and not df[col].isin([0, 1]).all()
]
n_cols = len(numeric_cols)

assert n_cols > 0, "No numeric columns to visualize"

# Three plots per row, with the number of rows rounded up.
n_rows = (n_cols + 2) // 3
# squeeze=False always returns a 2-D array of axes, so flattening is valid for
# a single row as well (wrapping the 1-D array in a list would make `axes[i]`
# an array instead of an Axes).
fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

# One box plot per feature, split by cover type.
for ax, col in zip(axes, numeric_cols):
    sns.boxplot(data=df, x=target, y=col, ax=ax)
    ax.set_title(f'{col}')

# Hide the grid cells that received no plot.
for ax in axes[n_cols:]:
    ax.set_visible(False)

plt.tight_layout()

png

# Histogram grid for the same features: three plots per row, rows rounded up.
n_rows = (n_cols + 2) // 3
# squeeze=False always returns a 2-D array of axes, so flattening is valid for
# a single row as well (wrapping the 1-D array in a list would make `axes[i]`
# an array instead of an Axes).
fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, col in zip(axes, numeric_cols):
    ax.hist(df[col], bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    ax.set_title(f'Distribution of {col}')
    ax.set_xlabel(col)
    ax.set_ylabel('Frequency')

# Hide the grid cells that received no plot.
for ax in axes[len(numeric_cols):]:
    ax.set_visible(False)

plt.tight_layout()

png

The wilderness area columns form a one-hot encoding.

columns = ['Wilderness_Area_0', 'Wilderness_Area_1', 'Wilderness_Area_2', 'Wilderness_Area_3']
# make sure they are a one-hot encoding: exactly one flag set in every row
assert (df[columns].sum(axis='columns') == 1).all()

# Mean of each indicator per cover type, i.e. the share of that class's rows
# carrying the flag.  The labels are zero-based (0..6), so the classes are
# taken as the group-by returns them; reindexing on a hard-coded 1..7 range
# would drop class 0 and append an all-zero column.
class_shares = df.groupby(target)[columns].mean()

# Rows: indicator columns; columns: class labels (as strings for the ticks).
result_df = class_shares.T
result_df.columns = [str(label) for label in result_df.columns]

plt.figure(figsize=(8, len(columns) * 0.5 + 2))
sns.heatmap(result_df, annot=True, fmt='.4f', cmap='Blues')
plt.xlabel('Target Class')
plt.ylabel('OneHot Variable')
plt.title('Wilderness area percentage')
plt.tight_layout()

png

The same is true of the soil type, which is divided into 40 subtypes.

columns = [f'Soil_Type_{i}' for i in range(40)]
# make sure they are a one-hot encoding: exactly one flag set in every row
assert (df[columns].sum(axis='columns') == 1).all()

# Mean of each indicator per cover type, i.e. the share of that class's rows
# carrying the flag.  The labels are zero-based (0..6), so the classes are
# taken as the group-by returns them; reindexing on a hard-coded 1..7 range
# would drop class 0 and append an all-zero column.
class_shares = df.groupby(target)[columns].mean()

# Rows: indicator columns; columns: class labels (as strings for the ticks).
result_df = class_shares.T
result_df.columns = [str(label) for label in result_df.columns]

plt.figure(figsize=(8, len(columns) * 0.5 + 2))
sns.heatmap(result_df, annot=True, fmt='.4f', cmap='Blues')
plt.xlabel('Target Class')
plt.ylabel('OneHot Variable')
plt.title('Soil type percentage')
plt.tight_layout()

png

The problem is heavily imbalanced, with almost half of the entries in class #1. This means that a model that always predicts #1 will achieve 48.76% accuracy, and that the two most frequent classes (#0 and #1) together account for 85.22% of the entries.

# Number of rows per class, in ascending label order, with its share of the
# whole dataset.
class_counts = df.Cover_Type.value_counts().sort_index()
total_rows = len(df)
for label, count in class_counts.items():
    print(f"Class {label}: {count:>8,} entries ({count / total_rows:.2%})")
Class 0:  211,840 entries (36.46%)
Class 1:  283,301 entries (48.76%)
Class 2:   35,754 entries (6.15%)
Class 3:    2,747 entries (0.47%)
Class 4:    9,493 entries (1.63%)
Class 5:   17,367 entries (2.99%)
Class 6:   20,510 entries (3.53%)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Features as float32 and zero-based labels as int64.
X = data.data.astype(np.float32)
y = (data.target - 1).astype(np.int64)

# train/test split: 20% of the rows are held out, with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# The scaler is fitted on the training rows only and then applied to both
# parts; the results are cast back to float32.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train).astype(np.float32)
X_test = scaler.transform(X_test).astype(np.float32)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import balanced_accuracy_score

def _as_dataset(features, labels):
    """Wrap a (features, labels) pair of arrays as a TensorDataset."""
    return TensorDataset(torch.tensor(features), torch.tensor(labels))


batch_size = 1_024

train_ds = _as_dataset(X_train, y_train)
test_ds = _as_dataset(X_test, y_test)

# Only the training batches are shuffled.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size)
def get_device():
    """Return the name of the preferred torch device.

    CUDA is tried first, then the MPS backend, and the CPU is the fallback.

    Returns:
        str: ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    if torch.cuda.is_available():
        return 'cuda'
    # Query MPS through `torch.backends.mps`; `torch.mps.is_available` is not
    # present in every torch release, and builds without the backend module
    # simply fall through to the CPU.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'

Since the goal of this article is to gather experience with stochastic weight averaging, we use a simple, small neural network as the baseline. A slightly more sophisticated model would improve performance, but we’ll stick to the one below.

class MLP(nn.Module):
    """Fully connected network with ReLU activations between linear layers.

    With the default arguments the layer widths are 54 -> 256 -> 128 -> 64 -> 7.

    Args:
        in_features: size of each input row.
        hidden_sizes: widths of the hidden layers, in order.
        n_classes: number of output scores per row.
    """

    def __init__(self, in_features=54, hidden_sizes=(256, 128, 64), n_classes=7):
        super().__init__()
        layers = []
        width = in_features
        # Each hidden layer is a Linear followed by a ReLU; the layers are
        # created in order, so the Sequential indices (0, 2, 4, ...) of the
        # Linear modules do not depend on how the list is built.
        for hidden in hidden_sizes:
            layers.append(nn.Linear(width, hidden))
            layers.append(nn.ReLU())
            width = hidden
        # Output layer: raw scores, no activation.
        layers.append(nn.Linear(width, n_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the layer stack for the batch ``x``."""
        return self.model(x)

# Model and loss, placed on the selected device.
device = get_device()
model = MLP().to(device)
criterion = nn.CrossEntropyLoss()

# Plain SGD with momentum and weight decay.
sgd_settings = dict(lr=0.1, momentum=0.9, weight_decay=5e-4)
optimizer = optim.SGD(model.parameters(), **sgd_settings)

# Total number of epochs, and the epoch index at which the training loop
# switches between the two schedulers below.
epochs = 60
swa_start = 30

# Cosine annealing whose period equals `swa_start`.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=swa_start)

# Averaged copy of the model, plus the scheduler driving the SWA learning rate.
swa_model = AveragedModel(model).to(device)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

def evaluate(model, loader, device="cuda"):
    """Score ``model`` on every batch yielded by ``loader``.

    The model is switched to eval mode and run without gradient tracking; the
    predicted class of each row is the argmax over dimension 1 of the output.

    Args:
        model: module called on each input batch.
        loader: iterable of ``(inputs, labels)`` tensor pairs.
        device: device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(overall_accuracy, balanced_accuracy, per_class_accuracy)``
        where the last item is a list of floats, one per label value present
        in the data, in ascending label order.
    """
    model.eval()
    batch_preds, batch_labels = [], []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(1)

            batch_preds.append(predicted.cpu().numpy())
            batch_labels.append(targets.cpu().numpy())

    predictions = np.concatenate(batch_preds)
    labels = np.concatenate(batch_labels)

    # Boolean vector marking the correctly classified rows.
    hits = predictions == labels
    overall_acc = hits.mean()
    balanced_acc = balanced_accuracy_score(labels, predictions)

    # Per-class accuracy: fraction of hits among the rows of each label.
    per_class = [float(hits[labels == cls].mean()) for cls in np.unique(labels)]

    return overall_acc, balanced_acc, per_class

# Per-epoch history: learning rates, balanced accuracy of the SGD model on the
# test set, and its per-class accuracies.
lrs, sgd_accs, class_accs = [], [], []

from tqdm import trange
for epoch in (p := trange(epochs)):
    # One pass over the training data with plain SGD steps.
    model.train()
    for Xb, yb in train_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(Xb), yb)
        loss.backward()
        optimizer.step()

    # update scheduler
    # Up to and including `swa_start` the cosine schedule drives the learning
    # rate.  Only afterwards are the weights folded into the running average
    # and the SWA schedule stepped, so the average covers the tail of training
    # rather than its first epochs.
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
        lrs.append(swa_scheduler.get_last_lr())
    else:
        scheduler.step()
        lrs.append(scheduler.get_last_lr())

    # Track the test-set scores of the current (non-averaged) model.
    overall_acc, balanced_acc, class_acc = evaluate(model, test_loader, device)
    sgd_accs.append(balanced_acc)
    class_accs.append(class_acc)
    class_acc = ','.join(map(lambda x: f'{x:.2%}', class_acc))
    p.set_description(f"Epoch {epoch+1}/{epochs}: {overall_acc:.4f}/{balanced_acc:.4f}/{class_acc}")

# Refresh batch-norm statistics of the averaged model on the training data,
# then score the averaged model once.
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
swa_overall_acc, swa_balanced_acc, swa_class_acc = evaluate(swa_model, test_loader, device)

print(f"SWA accuracy: {swa_overall_acc:.4f}/{swa_balanced_acc:.4f}/{','.join(map(lambda x: f'{x:.2%}', swa_class_acc))}")
Epoch 60/60: 0.9437/0.8788/93.70%,96.23%,93.92%,74.71%,75.09%,87.47%,94.02%: 100%|██████████| 60/60 [02:30<00:00,  2.51s/it]


SWA accuracy: 0.9329/0.8634/92.26%,95.66%,92.28%,74.52%,71.53%,85.64%,92.50%
# Start a fresh figure so the curve is not drawn onto whichever figure happens
# to be current.
fig, ax = plt.subplots()
ax.plot(sgd_accs, label='SGD')
ax.set_xlabel('Epoch')
# `sgd_accs` holds the balanced accuracy recorded after each epoch, and the
# dashed line is the balanced accuracy of the averaged model.
ax.set_ylabel('Balanced accuracy')
ax.set_title("Test Dataset Accuracy")
ax.axhline(y=swa_balanced_acc, linestyle='dashed', color='green', label='SWA')
ax.legend();

png

# One row per epoch, one column per class.
class_accs = np.array(class_accs)
n_classes = class_accs.shape[1]

# 2 x 4 grid of panels sharing both axes; the panels beyond the number of
# classes are switched off.
fig, axes = plt.subplots(figsize=(12, 8), nrows=2, ncols=4, sharex=True, sharey=True)
for i, ax in enumerate(axes.ravel()):
    if i < n_classes:
        # Per-epoch accuracy of the SGD model on class i, with the averaged
        # model's accuracy on the same class as a dashed reference line.
        ax.plot(class_accs[:, i])
        ax.axhline(y=swa_class_acc[i], color='green', linestyle='dashed')
        ax.set_title(f'Class #{i}')
    else:
        ax.axis('off')
fig.tight_layout()

png

In general terms, the results are not very satisfactory – more sophisticated base models will achieve much higher accuracies. Also, SWA is a bit below the best iterations, so don’t expect to easily improve your model’s performance by using it. The idea, though, is that SWA is more robust and we don’t need to stop training at a specific point.