In Part I of this post we saw how to analyze the Galaxy Zoo dataset and built a simple model; here, we use a convolutional neural network (CNN) to build a much more effective one.
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
from pathlib import Path
import random
The data is downloaded from the Kaggle web page, as we did for Part I. The labels are stored in a Pandas dataframe.
# Directory with the training images; the dataset class defined later opens
# `<GalaxyID>.jpg` files inside it.
IMG_DIR = Path(r'../2022-04-01-galaxy-1/data/images_training_rev1')
# Training labels. Further down, the `GalaxyID` column is used as the sample
# identifier and `labels.columns[1:]` as the regression targets.
labels = pd.read_csv('../2022-04-01-galaxy-1/data/training_solutions_rev1.csv')
print(f"Found {len(labels)} entries with {len(labels.columns)} columns.")
Found 61578 entries with 38 columns.
The dataset of about 60,000 entries is split into training and test sets, with the latter comprising 20% of the entries. What we are splitting are the galaxy IDs – our `X_train` and `X_test` variables are in effect only lists of IDs, while for the labels we are splitting the actual data.
from sklearn.model_selection import train_test_split
# Inputs are only the galaxy identifiers (images are loaded lazily later on);
# targets are every column after the first one.
galaxy_ids = labels.GalaxyID
answer_columns = labels[labels.columns[1:]]

# Hold out 20% of the entries for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, Y_train, Y_test = train_test_split(
    galaxy_ids,
    answer_columns,
    test_size=0.20,
    random_state=42,
)

print(f"Training set size: {X_train.shape[0]}")
print(f"Testing set size: {X_test.shape[0]}")
Training set size: 49262
Testing set size: 12316
We can now define our custom `Dataset` class. It takes as input the list of galaxy IDs to be used and the corresponding labels, and returns the images as well as the corresponding labels as `torch.Tensor` objects. The images are preprocessed:
- first they are cropped at the center, since we have seen in Part I that most of the non-black pixels are in the center;
- they are resized to 69x69;
- they are randomly flipped around the horizontal axis;
- they are randomly rotated around the center;
- they are randomly flipped around the vertical axis;
- they are scaled.
Since random flipping and rotations happen every time an image is fetched, the network never sees exactly the same image twice, which helps to prevent overfitting. For convenience when plotting the original image, we pass the transformations as an input and define the helper function `get_transform()` to build them.
import torch
from torchvision import transforms
from torchvision.transforms import Resize, Grayscale, functional
from torchvision.io import read_image
from sklearn.metrics import mean_squared_error
from PIL import Image
class GalaxyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs for galaxies.

    Images are read lazily from ``path / '<id>.jpg'`` each time an item is
    requested and then passed through ``transform``.

    Args:
        X: galaxy identifiers; indexed with ``.iloc``, so a pandas Series
            (or anything exposing ``.iloc`` and ``len``) is expected.
        Y: per-galaxy target values, indexed with ``.iloc`` in the same
            order as ``X``.
        path: directory containing the images; must support the ``/``
            operator (e.g. ``pathlib.Path``).
        transform: callable applied to the opened PIL image.
    """

    def __init__(self, X, Y, path, transform):
        self.X, self.Y = X, Y
        self.path, self.transform = path, transform

    def __len__(self):
        """Return the number of galaxy identifiers held by the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the transformed image and its targets as a float tensor."""
        galaxy_id = self.X.iloc[idx]
        picture = Image.open(self.path / f'{galaxy_id}.jpg')
        pixels = self.transform(picture)
        target = torch.FloatTensor(self.Y.iloc[idx])
        return pixels, target
def get_transform(crop_and_resize, scale, augment=True):
    """Build the torchvision preprocessing pipeline for the galaxy images.

    Args:
        crop_and_resize: if true, centre-crop the image to 207x207 pixels and
            then resize it to 69x69.
        scale: if true, normalize each channel of the tensor with fixed
            mean/std constants.
        augment: if true (the default, matching the previous behaviour),
            apply a random horizontal flip, a random rotation in [0, 360]
            degrees and a random vertical flip. Pass ``False`` to obtain a
            deterministic pipeline, e.g. to display an untouched image or to
            evaluate on a fixed test set.

    Returns:
        A ``transforms.Compose`` that takes a PIL image and returns a tensor.
    """
    steps = []
    if crop_and_resize:
        steps += [
            transforms.CenterCrop((207, 207)),
            transforms.Resize((69, 69)),
        ]
    if augment:
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=(0, 360)),
            transforms.RandomVerticalFlip(p=0.5),
        ]
    # Conversion to tensor always happens, and must precede Normalize, which
    # operates on tensors.
    steps.append(transforms.ToTensor())
    if scale:
        # NOTE(review): these are the constants usually quoted for ImageNet,
        # not statistics computed from this dataset -- confirm this is intended.
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        )
    return transforms.Compose(steps)
# Seed PyTorch's global random number generator. The return value is bound to
# `_`, presumably so that a notebook cell does not echo it.
_ = torch.manual_seed(42)
First, we plot a few original images on the left and the cropped, resized and randomly rotated and flipped images on the right. Indeed most of the relevant features are preserved; we are helping the network by getting rid of useless pixels.
# Show the first six training galaxies: uncropped on the left, cropped and
# resized on the right. Both pipelines still apply the random flips/rotation
# from get_transform(); neither applies the normalization.
n_examples = 6
fig, axes = plt.subplots(figsize=(8, 24), nrows=n_examples, ncols=2)

uncropped_view = GalaxyDataset(X_train, Y_train, IMG_DIR, get_transform(False, False))
cropped_view = GalaxyDataset(X_train, Y_train, IMG_DIR, get_transform(True, False))

for row in range(n_examples):
    left_ax, right_ax = axes[row, 0], axes[row, 1]

    # Tensors come out channel-first; imshow wants channel-last.
    picture = uncropped_view[row][0].permute(1, 2, 0)
    left_ax.imshow(picture)
    left_ax.axis('off')
    left_ax.text(25, 30, 'Original', color='white')

    picture = cropped_view[row][0].permute(1, 2, 0)
    right_ax.imshow(picture)
    right_ax.axis('off')
    right_ax.text(5, 5, 'Cropped, resized, randomly rotated', color='white')

fig.tight_layout()
We are ready to create the CNN model, which is taken from https://github.com/jimsiak/kaggle-galaxies-pytorch and is inspired by the winning solution by Sander Dieleman.
import torch.nn as nn
import torch.nn.functional as F
class SanderDielemanNet(nn.Module):
    """Small CNN with four convolutions followed by three dense layers.

    The first dense layer expects 128 feature maps of 5x5, which is what a
    3-channel 69x69 input produces after the convolution/pooling stack below.

    Args:
        num_classes: number of output values per image.
    """

    def __init__(self, num_classes=37):
        super().__init__()

        # Feature extractor: convolutions interleaved with 2x2 max-pooling.
        self.conv1 = nn.Conv2d(3, 32, 6)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.conv4 = nn.Conv2d(128, 128, 3)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Regression head.
        self.fc1 = nn.Linear(128 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return one non-negative value per class for each image in ``x``."""
        n_images = x.shape[0]

        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        features = F.relu(self.conv3(features))
        features = self.pool4(F.relu(self.conv4(features)))

        # Flatten everything except the batch dimension for the dense layers.
        hidden = features.view(n_images, -1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        # The final ReLU clamps the outputs at zero from below.
        return F.relu(self.fc3(hidden))
# Training and test datasets use the same pipeline: crop/resize followed by
# normalization.
dataset_train = GalaxyDataset(X_train, Y_train, IMG_DIR, get_transform(True, True))
# Reshuffle the training samples at every epoch; without it SGD is fed the
# very same mini-batches in the very same order each time.
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True)
dataset_test = GalaxyDataset(X_test, Y_test, IMG_DIR, get_transform(True, True))
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

import torch.optim as optim
from torch.optim import lr_scheduler

model = SanderDielemanNet(num_classes=37).to(device)
# Summed (not averaged) squared error: the training loop accumulates the
# total over an epoch and normalizes it itself.
criterion = nn.MSELoss(reduction='sum')
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
import time

# Per-epoch root-mean-square error on the training and test sets.
losses_train, losses_test = [], []

for epoch in range(1, 101):
    start_time = time.time()

    # ---- training pass ----
    loss_train = 0.0
    # `targets` rather than `labels`, so the module-level `labels` DataFrame
    # is not overwritten by the loop variable.
    for inputs, targets in dataloader_train:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Accumulate the summed squared error for the epoch-level RMSE...
        loss_train += loss.item()
        # ...but back-propagate the mean, so the gradient scale does not
        # depend on the batch size.
        loss = loss / outputs.numel()
        loss.backward()
        optimizer.step()
    loss_train = np.sqrt(loss_train / len(dataset_train) / 37)
    losses_train.append(loss_train)

    # ---- evaluation pass ----
    with torch.no_grad():
        loss_test = 0.0
        for inputs, targets in dataloader_test:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss_test += loss.item()
        loss_test = np.sqrt(loss_test / len(dataset_test) / 37)
        losses_test.append(loss_test)

    # ReduceLROnPlateau only acts when it is given the monitored metric;
    # without this call the learning rate would never change.
    scheduler.step(loss_test)

    elapsed_time = time.time() - start_time
    # Uncomment for per-epoch progress:
    #print(f'Epoch {epoch:4d}: loss train={loss_train:.4f}, loss_test={loss_test:.4f} [{elapsed_time:.2f} s]')
# Training and test RMSE per epoch, on a logarithmic y axis.
plt.semilogy(losses_train, label='train')
plt.semilogy(losses_test, label='test')
plt.legend()
plt.xlabel('Epoch')
# Axis label fixed: it previously read 'Los Loss'.
plt.ylabel('Loss');
The model converges nicely, with train and test sets giving similar losses. We can try the model on a few test samples and check the results. To do that, we put together a human-readable description of the questions, grouped as needed, and look for the highest probability within each group.
# Human-readable answer labels, one inner list per question. The flattened
# list is used further down as the index for the 37 label values, so the
# order here has to follow the order of the label columns.
# Note: 'No bulge' occurs in two different groups.
groups = [
    # Overall appearance
    ['Smooth', 'Featured or disc', 'Star or artifact'],
    # Edge-on view
    ['Edge on', 'Not edge on'],
    # Bar
    ['Bar through center', 'No bar'],
    # Spiral pattern
    ['Spiral', 'No Spiral'],
    # Bulge prominence
    [
        'No bulge',
        'Just noticeable bulge',
        'Obvious bulge',
        'Dominant bulge',
    ],
    # Anything odd
    ['Odd Feature', 'No Odd Feature'],
    # Roundness
    ['Completely round', 'In between', 'Cigar shaped'],
    # Kind of odd feature
    [
        'Ring (Oddity)',
        'Lens or arc (Oddity)',
        'Disturbed (Oddity)',
        'Irregular (Oddity)',
        'Other (Oddity)',
        'Merger (Oddity)',
        'Dust lane (Oddity)',
    ],
    # Bulge shape
    ['Rounded bulge', 'Boxy bulge', 'No bulge'],
    # Arm winding
    ['Tightly wound arms', 'Medium wound arms', 'Loose wound arms'],
    # Number of arms
    [
        '1 Spiral Arm',
        '2 Spiral Arms',
        '3 Spiral Arms',
        '4 Spiral Arms',
        'More than four Spiral Arms',
        "Can't tell",
    ],
]
# Show ten random test galaxies together with, for every question, the most
# voted answer (exact) and the answer the model rates highest (pred).

# Loop-invariant pieces, built once instead of on every iteration.
flat_labels = sum(groups, [])
dataset_uncropped = GalaxyDataset(X_test, Y_test, IMG_DIR, get_transform(False, False))
dataset_cropped = GalaxyDataset(X_test, Y_test, IMG_DIR, get_transform(True, False))

for _ in range(10):
    # randrange excludes the upper bound; random.randint(0, len(X_test))
    # could return len(X_test) and index past the end of the test set.
    i = random.randrange(len(X_test))

    fig, (ax0, ax1) = plt.subplots(figsize=(8, 4), ncols=2)
    img = dataset_uncropped[i][0]
    img = img.permute(1, 2, 0)
    ax0.imshow(img)
    ax0.axis('off')
    ax0.text(25, 30, f'i={i} Original', color='white')
    img = dataset_cropped[i][0]
    img = img.permute(1, 2, 0)
    ax1.imshow(img)
    ax1.axis('off')
    ax1.text(5, 5, 'Cropped, resized, randomly rotated', color='white')
    plt.show()

    X_i, Y_i = dataset_test[i]
    X_i = X_i.to(device)
    Y_pred_i = model(X_i.unsqueeze(0)).detach().cpu().squeeze(0)
    # Convert to NumPy so the Series hold plain floats rather than tensors.
    Y_i = pd.Series(Y_i.numpy(), index=flat_labels)
    Y_pred_i = pd.Series(Y_pred_i.numpy(), index=flat_labels)

    print('------|----------------------|----------------------|------')
    print(' prob | meaning (exact) | meaning (pred) | prob')
    print('------|----------------------|----------------------|------')
    # Select each group by position, not by label: 'No bulge' appears in two
    # groups, so label-based selection would pull in values belonging to the
    # other question as well.
    start = 0
    for group in groups:
        stop = start + len(group)
        S_i, S_pred_i = Y_i.iloc[start:stop], Y_pred_i.iloc[start:stop]
        print(f'{S_i.max():.3f} {S_i.idxmax():>22s} | {S_pred_i.idxmax():22s} {S_pred_i.max():.3f}')
        start = stop
    print('\n\n')
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.670 Featured or disc | Featured or disc 0.719
0.670 Edge on | Edge on 0.519
0.000 Bar through center | No bar 0.133
0.000 Spiral | Spiral 0.093
0.670 No bulge | No bulge 0.400
0.727 No Odd Feature | No Odd Feature 0.775
0.280 Cigar shaped | Cigar shaped 0.287
0.137 Other (Oddity) | Other (Oddity) 0.097
0.670 No bulge | No bulge 0.400
0.000 Tightly wound arms | Tightly wound arms 0.058
0.000 1 Spiral Arm | Can't tell 0.035
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.673 Smooth | Smooth 0.567
0.308 Not edge on | Not edge on 0.336
0.308 No bar | No bar 0.315
0.308 No Spiral | Spiral 0.098
0.183 Just noticeable bulge | Just noticeable bulge 0.155
0.962 No Odd Feature | No Odd Feature 0.967
0.628 In between | In between 0.482
0.038 Ring (Oddity) | Disturbed (Oddity) 0.023
0.000 Rounded bulge | No bulge 0.041
0.000 Tightly wound arms | Tightly wound arms 0.063
0.000 1 Spiral Arm | Can't tell 0.058
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.791 Smooth | Smooth 0.829
0.182 Not edge on | Not edge on 0.109
0.182 No bar | No bar 0.090
0.182 No Spiral | Spiral 0.000
0.080 Just noticeable bulge | Obvious bulge 0.058
0.824 No Odd Feature | No Odd Feature 0.895
0.757 In between | In between 0.670
0.059 Disturbed (Oddity) | Other (Oddity) 0.047
0.000 Rounded bulge | Rounded bulge 0.007
0.000 Tightly wound arms | Tightly wound arms 0.011
0.000 1 Spiral Arm | Can't tell 0.001
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.652 Featured or disc | Featured or disc 0.752
0.604 Not edge on | Not edge on 0.738
0.604 No bar | No bar 0.680
0.426 Spiral | Spiral 0.554
0.409 Just noticeable bulge | Just noticeable bulge 0.413
0.837 No Odd Feature | No Odd Feature 0.762
0.315 Completely round | Completely round 0.271
0.098 Ring (Oddity) | Disturbed (Oddity) 0.035
0.048 Rounded bulge | No bulge 0.102
0.385 Tightly wound arms | Tightly wound arms 0.432
0.243 Can't tell | Can't tell 0.223
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.669 Smooth | Smooth 0.795
0.289 Not edge on | Not edge on 0.161
0.289 No bar | No bar 0.140
0.176 No Spiral | Spiral 0.063
0.163 Dominant bulge | Dominant bulge 0.060
0.854 No Odd Feature | No Odd Feature 0.874
0.612 Completely round | Completely round 0.677
0.094 Other (Oddity) | Irregular (Oddity) 0.033
0.127 No bulge | No bulge 0.057
0.114 Medium wound arms | Tightly wound arms 0.047
0.114 Can't tell | Can't tell 0.041
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.876 Smooth | Smooth 0.804
0.073 Not edge on | Not edge on 0.151
0.073 No bar | No bar 0.129
0.049 No Spiral | Spiral 0.020
0.073 Obvious bulge | Obvious bulge 0.070
0.882 No Odd Feature | No Odd Feature 0.894
0.536 Completely round | Completely round 0.585
0.118 Disturbed (Oddity) | Other (Oddity) 0.035
0.000 Rounded bulge | Rounded bulge 0.000
0.024 Medium wound arms | Medium wound arms 0.011
0.024 3 Spiral Arms | Can't tell 0.014
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.873 Smooth | Smooth 0.761
0.127 Not edge on | Not edge on 0.190
0.127 No bar | No bar 0.177
0.122 Spiral | Spiral 0.067
0.064 Obvious bulge | Obvious bulge 0.117
0.864 No Odd Feature | No Odd Feature 0.877
0.467 In between | Completely round 0.389
0.082 Irregular (Oddity) | Other (Oddity) 0.043
0.000 Rounded bulge | Rounded bulge 0.000
0.122 Medium wound arms | Medium wound arms 0.035
0.081 Can't tell | Can't tell 0.020
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.636 Featured or disc | Smooth 0.616
0.636 Not edge on | Not edge on 0.334
0.636 No bar | No bar 0.329
0.467 No Spiral | Spiral 0.070
0.422 Obvious bulge | Obvious bulge 0.219
0.884 No Odd Feature | No Odd Feature 0.900
0.227 In between | In between 0.466
0.077 Disturbed (Oddity) | Other (Oddity) 0.031
0.000 Rounded bulge | Rounded bulge 0.000
0.169 Medium wound arms | Tightly wound arms 0.023
0.169 Can't tell | Can't tell 0.039
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.718 Smooth | Smooth 0.526
0.179 Not edge on | Not edge on 0.291
0.179 No bar | No bar 0.263
0.093 Spiral | Spiral 0.056
0.120 Just noticeable bulge | Obvious bulge 0.133
0.876 No Odd Feature | No Odd Feature 0.809
0.438 Cigar shaped | In between 0.400
0.062 Merger (Oddity) | Other (Oddity) 0.076
0.036 Rounded bulge | Rounded bulge 0.128
0.093 Tightly wound arms | Tightly wound arms 0.031
0.093 Can't tell | Can't tell 0.060
------|----------------------|----------------------|------
prob | meaning (exact) | meaning (pred) | prob
------|----------------------|----------------------|------
0.950 Featured or disc | Featured or disc 0.951
0.950 Not edge on | Not edge on 0.931
0.922 No bar | No bar 0.864
0.671 Spiral | Spiral 0.862
0.448 Just noticeable bulge | Just noticeable bulge 0.541
0.647 No Odd Feature | No Odd Feature 0.870
0.028 In between | Completely round 0.022
0.242 Ring (Oddity) | Disturbed (Oddity) 0.053
0.044 No bulge | No bulge 0.037
0.537 Tightly wound arms | Tightly wound arms 0.481
0.266 Can't tell | 2 Spiral Arms 0.348