In Part I of this post we saw how to analyze the Galaxy Zoo dataset and built a simple model; here, we use a convolutional neural network (CNN) to build a much more effective one.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import random

The data is downloaded from the Kaggle web page, as we did for Part I. The labels are stored in a Pandas dataframe.

IMG_DIR = Path(r'../2022-04-01-galaxy-1/data/images_training_rev1')
labels = pd.read_csv('../2022-04-01-galaxy-1/data/training_solutions_rev1.csv')
print(f"Found {len(labels)} entries with {len(labels.columns)} columns.")
Found 61578 entries with 38 columns.

The dataset of about 60,000 entries is split into training and test sets, with the latter comprising 20% of the entries. What we split are the galaxy IDs: our X_train and X_test variables are in effect only lists of IDs, while for the labels we split the actual data.

from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = \
    train_test_split(labels.GalaxyID, labels[labels.columns[1:]],  
                     test_size=0.20, random_state=42)
print(f"Training set size: {X_train.shape[0]}")
print(f"Testing set size: {X_test.shape[0]}")
Training set size: 49262
Testing set size: 12316
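
The sizes printed above follow from rounding the test fraction up, and the IDs and labels stay aligned because both are reordered with the same permutation. The sketch below illustrates the idea with the standard library only; `split_ids_and_labels` is a hypothetical helper written for this illustration, not scikit-learn's implementation, and its shuffling will not reproduce the exact split obtained above.

```python
import math
import random

def split_ids_and_labels(ids, labels, test_fraction=0.20, seed=42):
    """Shuffle ids and labels with one shared permutation, then cut off the test part."""
    n_test = math.ceil(test_fraction * len(ids))  # the test size is rounded up
    order = list(range(len(ids)))
    random.Random(seed).shuffle(order)
    test_idx, train_idx = order[:n_test], order[n_test:]

    def pick(seq, idx):
        return [seq[i] for i in idx]

    return (pick(ids, train_idx), pick(ids, test_idx),
            pick(labels, train_idx), pick(labels, test_idx))

# Toy stand-ins: ten galaxy IDs, each with a one-element label row derived from the ID.
toy_ids = [100 + i for i in range(10)]
toy_labels = [[gid / 1000] for gid in toy_ids]
ids_train, ids_test, lab_train, lab_test = split_ids_and_labels(toy_ids, toy_labels)
print(len(ids_train), len(ids_test))  # 8 2

# The same rounding applied to the real dataset size gives the numbers printed above.
n = 61578
print(n - math.ceil(0.20 * n), math.ceil(0.20 * n))  # 49262 12316
```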

We can now define our custom Dataset class. It takes as input the list of galaxy IDs to be used and the corresponding labels, and returns each image, together with its labels, as a torch.Tensor. The images are preprocessed:

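  1. first they are cropped at the center to 207x207 pixels, since we have seen in Part I that most of the non-black pixels are in the center;
  2. they are resized to 69x69;
  3. they are randomly flipped horizontally (mirrored left to right);
  4. they are randomly rotated around the center by an angle between 0 and 360 degrees;
  5. they are randomly flipped vertically (mirrored top to bottom);
  6. they are converted to tensors and normalized channel by channel, using the mean and standard deviation commonly used for ImageNet.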

Since random flips and rotations are applied every time an image is fetched, the network almost never sees exactly the same image twice, which helps to prevent overfitting. To make it easy to plot the images without cropping or normalization, the transformations are passed to the dataset as an argument, and the helper function get_transform() builds them.
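
To make the geometry of these steps concrete, here is a minimal sketch on a toy 12x12 image stored as a list of lists. The helpers `center_crop`, `downsample` and `random_flips` are hypothetical stand-ins written for this illustration, not the torchvision transforms: block averaging only approximates what Resize does, and the rotation step is left out because it needs interpolation. The toy sizes 12 -> 6 -> 2 mirror the factor of three between 207 and 69.

```python
import random

def center_crop(img, size):
    """Keep the central size x size window of a 2D list-of-lists image."""
    h, w = len(img), len(img[0])
    top, left = (h - size) // 2, (w - size) // 2
    return [row[left:left + size] for row in img[top:top + size]]

def downsample(img, factor):
    """Average non-overlapping factor x factor blocks (a crude stand-in for Resize)."""
    n = len(img) // factor
    return [[sum(img[r * factor + dr][c * factor + dc]
                 for dr in range(factor) for dc in range(factor)) / factor ** 2
             for c in range(n)]
            for r in range(n)]

def random_flips(img, rng):
    """Mirror the image left-right and/or top-bottom, each with probability 0.5."""
    if rng.random() < 0.5:
        img = [row[::-1] for row in img]  # horizontal flip
    if rng.random() < 0.5:
        img = img[::-1]                   # vertical flip
    return img

rng = random.Random(42)
image = [[r * 12 + c for c in range(12)] for r in range(12)]  # toy 12x12 "galaxy"
small = downsample(center_crop(image, 6), 3)                  # 12 -> 6 -> 2
views = [random_flips(small, rng) for _ in range(4)]          # one fresh view per fetch
print(small)  # [[52.0, 55.0], [88.0, 91.0]]
```

Every view contains the same pixel values in a different arrangement, which is why the labels can be reused unchanged for each augmented image.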

import torch
from torchvision import transforms
from torchvision.transforms import Resize, Grayscale, functional
from torchvision.io import read_image
from sklearn.metrics import mean_squared_error
from PIL import Image
class GalaxyDataset(torch.utils.data.Dataset):

    def __init__(self, X, Y, path, transform):
        self.X = X
        self.Y = Y
        self.path = path
        self.transform = transform

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        img = Image.open(self.path / f'{self.X.iloc[idx]}.jpg')
        return self.transform(img), torch.FloatTensor(self.Y.iloc[idx])
def get_transform(crop_and_resize, scale):
	if crop_and_resize:
		transform = [
			transforms.CenterCrop((207, 207)),
			transforms.Resize((69, 69)),
		]
	else:
		transform = []
	transform += [
		transforms.RandomHorizontalFlip(p=0.5),
		transforms.RandomRotation(degrees=(0,360)),
		transforms.RandomVerticalFlip(p=0.5),
		transforms.ToTensor(),
	]
	if scale:
		transform.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
	return transforms.Compose(transform)
_ = torch.manual_seed(42)

First, we plot a few images at full size on the left and their cropped and resized versions on the right; the random rotations and flips are applied in both cases, so the two columns show different orientations. Most of the relevant features are preserved: we are helping the network by getting rid of uninformative pixels.

fig, axes = plt.subplots(figsize=(8, 24), nrows=6, ncols=2)
for i in range(6):
    img = GalaxyDataset(X_train, Y_train, IMG_DIR, get_transform(False, False))[i][0]
    img = img.permute(1, 2, 0)
    axes[i, 0].imshow(img)
    axes[i, 0].axis('off')
    axes[i, 0].text(25, 30, 'Original', color='white')
    img = GalaxyDataset(X_train, Y_train, IMG_DIR, get_transform(True, False))[i][0]
    img = img.permute(1, 2, 0)
    axes[i, 1].imshow(img)
    axes[i, 1].axis('off')
    axes[i, 1].text(5, 5, 'Cropped, resized, randomly rotated', color='white')
fig.tight_layout()

We are ready to create the CNN model, which is taken from https://github.com/jimsiak/kaggle-galaxies-pytorch and is inspired by the winning solution by Sander Dieleman.

import torch.nn as nn
import torch.nn.functional as F

class SanderDielemanNet(nn.Module): 
	def __init__(self, num_classes=37): 
		super(SanderDielemanNet, self).__init__() 
		# Convolutional and MaxPool layers 
		self.conv1 = nn.Conv2d(3, 32, 6)
		self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
		self.conv2 = nn.Conv2d(32, 64, 5)
		self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
		self.conv3 = nn.Conv2d(64, 128, 3)
		self.conv4 = nn.Conv2d(128, 128, 3)
		self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
		# Dense layers
		self.fc1 = nn.Linear(128 * 5 * 5, 128)
		self.fc2 = nn.Linear(128, 64)
		self.fc3 = nn.Linear(64, num_classes)

	def forward(self, x):
		batch_size = x.shape[0]
		# Convolutional and MaxPool layers 
		x = F.relu(self.conv1(x)) 
		x = self.pool1(x) 
		x = F.relu(self.conv2(x)) 
		x = self.pool2(x) 
		x = F.relu(self.conv3(x)) 
		x = F.relu(self.conv4(x)) 
		x = self.pool4(x)
		# Dense layers 
		x = x.view(batch_size, -1)
		x = F.relu(self.fc1(x)) 
		x = F.relu(self.fc2(x))
		x = F.relu(self.fc3(x)) 
		return(x)
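
The input size of the first dense layer, 128 * 5 * 5, follows from the 69x69 input and the kernel sizes above. A short check of that arithmetic, assuming stride 1 and no padding for the convolutions (the nn.Conv2d defaults) and 2x2 pooling with stride 2; `out_size` is a hypothetical helper name:

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride) for conv1, pool1, conv2, pool2, conv3, conv4, pool4
layers = [(6, 1), (2, 2), (5, 1), (2, 2), (3, 1), (3, 1), (2, 2)]

size = 69
trace = [size]
for kernel, stride in layers:
    size = out_size(size, kernel, stride)
    trace.append(size)

print(trace)              # [69, 64, 32, 28, 14, 12, 10, 5]
print(128 * size * size)  # 3200
```
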
dataset_train = GalaxyDataset(X_train, Y_train, IMG_DIR, get_transform(True, True))
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128)
dataset_test = GalaxyDataset(X_test, Y_test, IMG_DIR, get_transform(True, True))
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.optim as optim
from torch.optim import lr_scheduler
model = SanderDielemanNet(num_classes=37).to(device)
criterion = nn.MSELoss(reduction='sum')
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
import time
losses_train, losses_test = [], []
for epoch in range(1, 101):
    start_time = time.time()
    loss_train = 0.0
    for inputs, labels in dataloader_train:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss_train += loss.item()
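        # Average over batch entries and outputs before backpropagating
        # (np.product is no longer available in recent NumPy versions).
        loss /= outputs.numel()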
        loss.backward()
        optimizer.step()
    loss_train = np.sqrt(loss_train / len(dataset_train) / 37)
    losses_train.append(loss_train)

    with torch.no_grad():
        loss_test = 0.0
        for inputs, labels in dataloader_test:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss_test += loss.item()
        loss_test = np.sqrt(loss_test / len(dataset_test) / 37)
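        losses_test.append(loss_test)
    # ReduceLROnPlateau only lowers the learning rate when stepped with the monitored metric
    scheduler.step(loss_test)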
    elapsed_time = time.time() - start_time
    #print(f'Epoch {epoch:4d}: loss train={loss_train:.4f}, loss_test={loss_test:.4f} [{elapsed_time:.2f} s]')
plt.semilogy(losses_train, label='train')
plt.semilogy(losses_test, label='test')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('RMSE loss');
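
The loop above accumulates the summed squared error over all batches and only then divides by the number of samples and the 37 outputs, so the plotted quantity is the root-mean-square error over the whole set. The sketch below, with made-up numbers and a hypothetical helper `rmse_from_batches`, shows why this differs from averaging per-batch RMSE values when the batches have different sizes, as the last batch of an epoch usually does.

```python
import math

def rmse_from_batches(batches):
    """batches: iterable of (predictions, targets), each a list of rows of floats."""
    total_sq_err, n_values = 0.0, 0
    for preds, targets in batches:
        for p_row, t_row in zip(preds, targets):
            total_sq_err += sum((p - t) ** 2 for p, t in zip(p_row, t_row))
            n_values += len(p_row)
    return math.sqrt(total_sq_err / n_values)

# Two toy batches of unequal size, with two outputs per sample.
batch_a = ([[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]])  # RMSE 1 on its own
batch_b = ([[0.0, 0.0]], [[2.0, 2.0]])                          # RMSE 2 on its own

global_rmse = rmse_from_batches([batch_a, batch_b])
mean_of_batch_rmses = (rmse_from_batches([batch_a]) + rmse_from_batches([batch_b])) / 2

print(round(global_rmse, 4))  # 1.4142
print(mean_of_batch_rmses)    # 1.5
```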

The model converges nicely, with the training and test sets giving similar losses. We can now try the model on a few test images and check the results. To do that, we put together a human-readable description of the answers, grouped by question, and look for the highest probability within each group.

groups = [
    ['Smooth', 'Featured or disc', 'Star or artifact'],
    ['Edge on', 'Not edge on'],
    ['Bar through center', 'No bar'],
    ['Spiral', 'No Spiral'],
    ['No bulge', 'Just noticeable bulge', 'Obvious bulge', 'Dominant bulge'],
    ['Odd Feature', 'No Odd Feature'],
    ['Completely round', 'In between', 'Cigar shaped'],
    ['Ring (Oddity)', 'Lens or arc (Oddity)', 'Disturbed (Oddity)', 'Irregular (Oddity)',
        'Other (Oddity)', 'Merger (Oddity)', 'Dust lane (Oddity)'],
    ['Rounded bulge', 'Boxy bulge', 'No bulge'],
    ['Tightly wound arms', 'Medium wound arms', 'Loose wound arms'],
    ['1 Spiral Arm', '2 Spiral Arms', '3 Spiral Arms', '4 Spiral Arms', 'More than four Spiral Arms', "Can't tell"],
]
for _ in range(10):
    # randint includes both endpoints and could return an out-of-range index
    i = random.randrange(len(X_test))

    fig, (ax0, ax1) = plt.subplots(figsize=(8, 4), ncols=2)
    img = GalaxyDataset(X_test, Y_test, IMG_DIR, get_transform(False, False))[i][0]
    img = img.permute(1, 2, 0)
    ax0.imshow(img)
    ax0.axis('off')
    ax0.text(25, 30, f'i={i} Original', color='white')
    img = GalaxyDataset(X_test, Y_test, IMG_DIR, get_transform(True, False))[i][0]
    img = img.permute(1, 2, 0)
    ax1.imshow(img)
    ax1.axis('off')
    ax1.text(5, 5, 'Cropped, resized, randomly rotated', color='white')
    plt.show()

    X_i, Y_i = dataset_test[i]
    X_i = X_i.to(device)
    Y_pred_i = model(X_i.unsqueeze(0)).detach().cpu().squeeze(0)

    Y_i = pd.Series(Y_i, index=sum(groups, []))
    Y_pred_i = pd.Series(Y_pred_i, index=sum(groups, []))

    print('------|----------------------|----------------------|------')
    print(' prob |      meaning (exact) | meaning (pred)       | prob')
    print('------|----------------------|----------------------|------')
    for group in groups:
        S_i, S_pred_i = Y_i[group], Y_pred_i[group]
        print(f'{S_i.max():.3f} {S_i.idxmax():>22s} | {S_pred_i.idxmax():22s} {S_pred_i.max():.3f}')
    
    print('\n\n')

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.670       Featured or disc | Featured or disc       0.719
0.670                Edge on | Edge on                0.519
0.000     Bar through center | No bar                 0.133
0.000                 Spiral | Spiral                 0.093
0.670               No bulge | No bulge               0.400
0.727         No Odd Feature | No Odd Feature         0.775
0.280           Cigar shaped | Cigar shaped           0.287
0.137         Other (Oddity) | Other (Oddity)         0.097
0.670               No bulge | No bulge               0.400
0.000     Tightly wound arms | Tightly wound arms     0.058
0.000           1 Spiral Arm | Can't tell             0.035

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.673                 Smooth | Smooth                 0.567
0.308            Not edge on | Not edge on            0.336
0.308                 No bar | No bar                 0.315
0.308              No Spiral | Spiral                 0.098
0.183  Just noticeable bulge | Just noticeable bulge  0.155
0.962         No Odd Feature | No Odd Feature         0.967
0.628             In between | In between             0.482
0.038          Ring (Oddity) | Disturbed (Oddity)     0.023
0.000          Rounded bulge | No bulge               0.041
0.000     Tightly wound arms | Tightly wound arms     0.063
0.000           1 Spiral Arm | Can't tell             0.058

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.791                 Smooth | Smooth                 0.829
0.182            Not edge on | Not edge on            0.109
0.182                 No bar | No bar                 0.090
0.182              No Spiral | Spiral                 0.000
0.080  Just noticeable bulge | Obvious bulge          0.058
0.824         No Odd Feature | No Odd Feature         0.895
0.757             In between | In between             0.670
0.059     Disturbed (Oddity) | Other (Oddity)         0.047
0.000          Rounded bulge | Rounded bulge          0.007
0.000     Tightly wound arms | Tightly wound arms     0.011
0.000           1 Spiral Arm | Can't tell             0.001

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.652       Featured or disc | Featured or disc       0.752
0.604            Not edge on | Not edge on            0.738
0.604                 No bar | No bar                 0.680
0.426                 Spiral | Spiral                 0.554
0.409  Just noticeable bulge | Just noticeable bulge  0.413
0.837         No Odd Feature | No Odd Feature         0.762
0.315       Completely round | Completely round       0.271
0.098          Ring (Oddity) | Disturbed (Oddity)     0.035
0.048          Rounded bulge | No bulge               0.102
0.385     Tightly wound arms | Tightly wound arms     0.432
0.243             Can't tell | Can't tell             0.223

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.669                 Smooth | Smooth                 0.795
0.289            Not edge on | Not edge on            0.161
0.289                 No bar | No bar                 0.140
0.176              No Spiral | Spiral                 0.063
0.163         Dominant bulge | Dominant bulge         0.060
0.854         No Odd Feature | No Odd Feature         0.874
0.612       Completely round | Completely round       0.677
0.094         Other (Oddity) | Irregular (Oddity)     0.033
0.127               No bulge | No bulge               0.057
0.114      Medium wound arms | Tightly wound arms     0.047
0.114             Can't tell | Can't tell             0.041

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.876                 Smooth | Smooth                 0.804
0.073            Not edge on | Not edge on            0.151
0.073                 No bar | No bar                 0.129
0.049              No Spiral | Spiral                 0.020
0.073          Obvious bulge | Obvious bulge          0.070
0.882         No Odd Feature | No Odd Feature         0.894
0.536       Completely round | Completely round       0.585
0.118     Disturbed (Oddity) | Other (Oddity)         0.035
0.000          Rounded bulge | Rounded bulge          0.000
0.024      Medium wound arms | Medium wound arms      0.011
0.024          3 Spiral Arms | Can't tell             0.014

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.873                 Smooth | Smooth                 0.761
0.127            Not edge on | Not edge on            0.190
0.127                 No bar | No bar                 0.177
0.122                 Spiral | Spiral                 0.067
0.064          Obvious bulge | Obvious bulge          0.117
0.864         No Odd Feature | No Odd Feature         0.877
0.467             In between | Completely round       0.389
0.082     Irregular (Oddity) | Other (Oddity)         0.043
0.000          Rounded bulge | Rounded bulge          0.000
0.122      Medium wound arms | Medium wound arms      0.035
0.081             Can't tell | Can't tell             0.020

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.636       Featured or disc | Smooth                 0.616
0.636            Not edge on | Not edge on            0.334
0.636                 No bar | No bar                 0.329
0.467              No Spiral | Spiral                 0.070
0.422          Obvious bulge | Obvious bulge          0.219
0.884         No Odd Feature | No Odd Feature         0.900
0.227             In between | In between             0.466
0.077     Disturbed (Oddity) | Other (Oddity)         0.031
0.000          Rounded bulge | Rounded bulge          0.000
0.169      Medium wound arms | Tightly wound arms     0.023
0.169             Can't tell | Can't tell             0.039

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.718                 Smooth | Smooth                 0.526
0.179            Not edge on | Not edge on            0.291
0.179                 No bar | No bar                 0.263
0.093                 Spiral | Spiral                 0.056
0.120  Just noticeable bulge | Obvious bulge          0.133
0.876         No Odd Feature | No Odd Feature         0.809
0.438           Cigar shaped | In between             0.400
0.062        Merger (Oddity) | Other (Oddity)         0.076
0.036          Rounded bulge | Rounded bulge          0.128
0.093     Tightly wound arms | Tightly wound arms     0.031
0.093             Can't tell | Can't tell             0.060

------|----------------------|----------------------|------
 prob |      meaning (exact) | meaning (pred)       | prob
------|----------------------|----------------------|------
0.950       Featured or disc | Featured or disc       0.951
0.950            Not edge on | Not edge on            0.931
0.922                 No bar | No bar                 0.864
0.671                 Spiral | Spiral                 0.862
0.448  Just noticeable bulge | Just noticeable bulge  0.541
0.647         No Odd Feature | No Odd Feature         0.870
0.028             In between | Completely round       0.022
0.242          Ring (Oddity) | Disturbed (Oddity)     0.053
0.044               No bulge | No bulge               0.037
0.537     Tightly wound arms | Tightly wound arms     0.481
0.266             Can't tell | 2 Spiral Arms          0.348
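
The per-group lookup used for these tables can also be written by position rather than by label. The sketch below is one way to do it, using two of the groups and made-up probabilities; `top_answer_per_group` is a hypothetical helper. Slicing by running offset does not depend on the answer names being unique, which matters here because 'No bulge' appears in two different groups.

```python
def top_answer_per_group(groups, values):
    """Return (answer, value) with the highest value inside each consecutive group."""
    results, start = [], 0
    for names in groups:
        chunk = values[start:start + len(names)]  # values belonging to this question
        best = max(range(len(names)), key=chunk.__getitem__)
        results.append((names[best], chunk[best]))
        start += len(names)
    return results

toy_groups = [
    ['Smooth', 'Featured or disc', 'Star or artifact'],
    ['Edge on', 'Not edge on'],
]
toy_values = [0.2, 0.7, 0.1, 0.3, 0.4]  # made-up probabilities, one per answer

print(top_answer_per_group(toy_groups, toy_values))
# [('Featured or disc', 0.7), ('Not edge on', 0.4)]
```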