In this article we look at another application of autoencoders: denoising. We use the small dataset of the denoising dirty documents Kaggle competition and show how to make it work.

Denoising is a process in which an item, which is an image in this case, contains some ‘noise’, that is some unwanted and unessential feature. Image coffee stains on a paper, or some shades causes by the aging of the document itself, or signs coming from the printer: the goal is to remove them and obtain the original, not-dirty image.

import random
from pathlib import Path
import torch
from torchvision import datasets, transforms, models
import torch.nn.functional as F
from torch import nn
from torch import optim
from collections import OrderedDict

from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
random.seed(42)
torch.manual_seed(43);

The train dataset contains two images for each entry, one of which is ‘clean’ and the other is ‘dirty’. As we will see, all the images are quite similar, rendering the same text with different fonts families and font sizes. There are only a few images: 144 in the training dataset and 72 in the test, but since they are all quite similar it will suffice.

data_dir = Path('./data')
train_dir = data_dir / 'train'
train_cleaned_dir = data_dir / 'train_cleaned'
test_dir = data_dir / 'test'

train_images = sorted(train_dir.glob('*.png'))
train_cleaned_images = sorted(train_cleaned_dir.glob('*.png'))
test_images = sorted(test_dir.glob('*.png'))

print('Number of Images in train:', len(train_images))
print('Number of Images in train_cleaned:', len(train_cleaned_images))
print('Number of Images in test:', len(test_images))
Number of Images in train: 144
Number of Images in train_cleaned: 144
Number of Images in test: 72
transform = transforms.Compose([
    transforms.Resize((320, 480)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

X = []
for image in train_images:
    pil_image = Image.open(image)
    pil_image = transform(pil_image)
    X.append(pil_image)
    
Y = []
for image in train_cleaned_images:
    pil_image = Image.open(image)
    pil_image = transform(pil_image)
    Y.append(pil_image)
    
test_images_transformed = []
for image in test_images:
    pil_image = Image.open(image)
    pil_image = transform(pil_image)
    test_images_transformed.append(pil_image)
def imshow(image, ax, title):
    if ax is None:
        fig, ax = plt.subplots()
    # PyTorch tensors assume the color channel is the first dimension
    # but matplotlib assumes is the third dimension
    image = image.transpose((1, 2, 0))
    
    # undo preprocessing
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    image = std * image + mean
    
    # image needs to be clipped between 0 and 1 or it looks like noise when displayed
    image = np.clip(image, 0, 1)
    
    ax.imshow(image)
    ax.grid(False)
    ax.set_title(title)
dataset = [(X[i], Y[i]) for i in range(len(X))]
random.shuffle(dataset)

split_size = 0.9
index = int(len(dataset)*split_size)

train_dataset = dataset[:index]
valid_dataset = dataset[index:]

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32)
images, targets = next(iter(train_loader))
images, targets = images.numpy(), targets.numpy()

def show_pair(i):
    plt.figure(figsize=(12, 14))
    ax = plt.subplot(1, 2, 1)
    imshow(images[i], ax, 'Original Image')
    ax = plt.subplot(1, 2, 2)
    imshow(targets[i], ax, 'Denoised Image')
for i in range(20):
    show_pair(i)
    plt.show()

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

The autoencoder itself is quite classical, with two layers of convolutional neural networks in the encoder and three such layers in the decoder. A max pool operator is used to reduce the image dimensions by half in each step, and an interpolator to extend the image dimensions in the decoder. The structure of the two is symmetric, as usual.

class DenoiserAutoencoder(nn.Module):
    
    def __init__(self):
        super().__init__();
        self.encoder1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.decoder1 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.decoder2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.decoder3 = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        
    def forward(self, x):
        # encoding
        x = F.relu(self.encoder1(x))
        x = self.pool(x)
        x = F.relu(self.encoder2(x))
        x = self.pool(x)
        # decoding
        x = F.relu(self.decoder1(x))
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = F.relu(self.decoder2(x))
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = torch.sigmoid(self.decoder3(x))
        return x
model = DenoiserAutoencoder()
print(model)
DenoiserAutoencoder(
  (encoder1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (decoder1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder3): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.has_mps:
    device = 'mps'
model = model.to(device)
print(f"Using device '{device}'")
Using device 'cpu'
def train(model, train_loader, valid_loader, num_epochs):
    
    for epoch in range(num_epochs):
        training_loss = 0.0        
        for images, targets in train_loader:            
            images, targets = images.to(device), targets.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            training_loss += loss.item()
        
        with torch.no_grad():
            valid_loss = 0
            for images, targets in valid_loader:
                images, targets = images.to(device), targets.to(device)
            
                outputs = model(images)
                loss = criterion(outputs, targets)
                
                valid_loss += loss.item()
        
        print(f'Epoch: {epoch + 1: 2d}/{num_epochs}    Training Loss: {training_loss/len(train_loader):.3f}    ' \
              f'Testing Loss: {valid_loss/len(valid_loader):.3f}')
train(model, train_loader, valid_loader, 50)
torch.save(model.state_dict(), './model.pt')
Epoch: 1/50    Training Loss: 0.273    Testing Loss: 0.228
Epoch: 2/50    Training Loss: 0.228    Testing Loss: 0.214
Epoch: 3/50    Training Loss: 0.208    Testing Loss: 0.195
Epoch: 4/50    Training Loss: 0.182    Testing Loss: 0.167
Epoch: 5/50    Training Loss: 0.159    Testing Loss: 0.150
Epoch: 6/50    Training Loss: 0.147    Testing Loss: 0.141
Epoch: 7/50    Training Loss: 0.138    Testing Loss: 0.134
Epoch: 8/50    Training Loss: 0.131    Testing Loss: 0.124
Epoch: 9/50    Training Loss: 0.124    Testing Loss: 0.121
Epoch: 10/50    Training Loss: 0.119    Testing Loss: 0.117
Epoch: 11/50    Training Loss: 0.116    Testing Loss: 0.113
Epoch: 12/50    Training Loss: 0.110    Testing Loss: 0.110
Epoch: 13/50    Training Loss: 0.107    Testing Loss: 0.106
Epoch: 14/50    Training Loss: 0.104    Testing Loss: 0.103
Epoch: 15/50    Training Loss: 0.101    Testing Loss: 0.100
Epoch: 16/50    Training Loss: 0.098    Testing Loss: 0.098
Epoch: 17/50    Training Loss: 0.095    Testing Loss: 0.096
Epoch: 18/50    Training Loss: 0.093    Testing Loss: 0.094
Epoch: 19/50    Training Loss: 0.091    Testing Loss: 0.092
Epoch: 20/50    Training Loss: 0.090    Testing Loss: 0.091
Epoch: 21/50    Training Loss: 0.088    Testing Loss: 0.089
Epoch: 22/50    Training Loss: 0.087    Testing Loss: 0.088
Epoch: 23/50    Training Loss: 0.085    Testing Loss: 0.086
Epoch: 24/50    Training Loss: 0.084    Testing Loss: 0.085
Epoch: 25/50    Training Loss: 0.082    Testing Loss: 0.083
Epoch: 26/50    Training Loss: 0.081    Testing Loss: 0.082
Epoch: 27/50    Training Loss: 0.079    Testing Loss: 0.081
Epoch: 28/50    Training Loss: 0.078    Testing Loss: 0.080
Epoch: 29/50    Training Loss: 0.077    Testing Loss: 0.078
Epoch: 30/50    Training Loss: 0.076    Testing Loss: 0.077
Epoch: 31/50    Training Loss: 0.075    Testing Loss: 0.076
Epoch: 32/50    Training Loss: 0.074    Testing Loss: 0.075
Epoch: 33/50    Training Loss: 0.073    Testing Loss: 0.075
Epoch: 34/50    Training Loss: 0.072    Testing Loss: 0.074
Epoch: 35/50    Training Loss: 0.072    Testing Loss: 0.074
Epoch: 36/50    Training Loss: 0.071    Testing Loss: 0.074
Epoch: 37/50    Training Loss: 0.071    Testing Loss: 0.075
Epoch: 38/50    Training Loss: 0.071    Testing Loss: 0.079
Epoch: 39/50    Training Loss: 0.072    Testing Loss: 0.083
Epoch: 40/50    Training Loss: 0.076    Testing Loss: 0.072
Epoch: 41/50    Training Loss: 0.074    Testing Loss: 0.073
Epoch: 42/50    Training Loss: 0.070    Testing Loss: 0.069
Epoch: 43/50    Training Loss: 0.068    Testing Loss: 0.070
Epoch: 44/50    Training Loss: 0.068    Testing Loss: 0.068
Epoch: 45/50    Training Loss: 0.067    Testing Loss: 0.067
Epoch: 46/50    Training Loss: 0.066    Testing Loss: 0.067
Epoch: 47/50    Training Loss: 0.065    Testing Loss: 0.066
Epoch: 48/50    Training Loss: 0.065    Testing Loss: 0.065
Epoch: 49/50    Training Loss: 0.064    Testing Loss: 0.065
Epoch: 50/50    Training Loss: 0.064    Testing Loss: 0.064
model.load_state_dict(torch.load('./model.pt'))
<All keys matched successfully>

To analyze the performances, first we check the denoising on the training dataset, since we have the original clean image and the corresponding dirty one. Results are generally good, albeit sometimes the denoised image is too blurry.

def plot_triplet(n):
    fig, (ax0, ax1, ax2) = plt.subplots(figsize=(14, 8), ncols=3)
    image, image_clean = X[n], Y[n]
    image_pred = model(image.unsqueeze(0).to(device)).cpu().detach().numpy()
    imshow(image.numpy(), ax0, 'Original Dirty Image')
    imshow(image_clean.numpy(), ax1, 'Original Clean Image')
    imshow(image_pred.squeeze(0), ax2, 'Denoised Image')
    fig.tight_layout()
for n in random.sample(range(len(X)), 10):
    plot_triplet(n)

png

png

png

png

png

png

png

png

png

png

For the test dataset we only have the original (dirty image), so we plot that and the denoised one. When the noise is low-frequency compared with the text, results are excellent; instead when the text and the noise have similar frequencies, results are less good. Given the small size of the dataset, though, we can be quite satisfied by the performances of our autoencoder.

for n in random.sample(range(len(test_images)), 10):
    image = test_images_transformed[n]
    image = image.unsqueeze(0).to(device)
    output = model(image)

    image, output = image.detach().cpu().numpy(), output.detach().cpu().numpy()

    plt.figure(figsize=(12,14))
    ax = plt.subplot(1,2,1)
    imshow(image[0], ax, 'Original Image')

    ax = plt.subplot(1,2,2)
    imshow(output[0], ax, 'Denoised Image')
    
    plt.show()

png

png

png

png

png

png

png

png

png

png