Debug model deep learning via visualisasi saliency map

Published

May 30, 2022

Saya menggunakan PyTorch sebagai pengantar untuk kontrol yang lebih detail. Ada cukup banyak ekuivalen dalam Tensorflow dan Keras di internet.

Tidak dapat dipungkiri bahwa deep learning memiliki kemampuan baik untuk tugas-tugas klasifikasi citra. Namun, deep learning tidak dapat menjelaskan mengapa dan bagaimana cara ia memperoleh suatu output. Dia bisa saja mengklasifikasikan citra kucing dan melabelinya dengan “kucing”, namun ia tidak bisa menjawab “bagaimana kamu tahu bahwa ini gambar kucing?”

Memang sangat sulit untuk membuat model yang dapat dijelaskan. Namun ada satu cara sederhana sebagai upaya untuk menjelaskan hasil klasifikasi suatu model, yaitu dengan visualisasi saliency map. Sederhananya, saliency map adalah bagian (region) dari citra, di mana mata kita meletakkan fokus pertama saat melihat suatu citra. Dalam konteks klasifikasi citra, saliency map menunjukkan area yang menjadi “ciri khas” suatu kelas 1.

Studi kasus: Isyana-Raisa Classifier

Motivating illustration:

Classifier yang sudah dilatih dapat melihat bagian wajah yang menjadi karakteristik wajah Isyana (alis, bibir atas, dan beberapa mole) dan Raisa (pangkal alis, bibir bawah, nostril, dan alae hidung).

Ide saliency map sederhana. Misal kita memiliki suatu citra input (grayscale, 1 channel) I \in \mathbb{R}^{m \times n}. Sebagai contoh, kita punya neural network untuk binary classification sebagai fungsi net(I) dengan aktivasi sigmoid di layer output. Output sigmoid akan kita gunakan sebagai skor klasifikasi. Ambil gradien dari input (dinotasikan dengan \nabla_I), yaitu derivatif dari skor klasifikasi terhadap setiap pixel pada input, sebagai berikut:

\nabla_I = \frac{\partial }{\partial I}net(I),

di mana \nabla_I memiliki ukuran sama dengan I. Berdasarkan magnitude (nilai absolut) dari gradien, kita bisa menentukan saliency map M \in \mathbb{R}^{m \times n} sebagai berikut:

M = \left\lvert \nabla_I \right\rvert.

Ingat kembali bahwa gradien input dapat dilihat sebagai rate of change. Jika nilai gradien suatu pixel pada input besar, maka sedikit saja perubahan nilai yang diterapkan pada pixel tersebut akan semakin mudah mempengaruhi skor klasifikasi. Dengan kata lain, output akan semakin sensitif terhadap perubahan kecil pada pixel dengan nilai gradien yang besar. Dengan konteks saliency map (atau magnitude gradien input), kita bisa menganggap bahwa pixel-pixel dengan nilai saliency besar berkorespondensi dengan lokasi objek (atau ciri khas suatu kelas) pada suatu citra 2.

Implementasi

Dataset

Saya menggunakan API Flickr (versi python) untuk mengambil gambar secara masif. Singkatnya, saya mengunduh gambar dengan kata kunci “Isyana Sarasvati” dan “Raisa Andriana” sebanyak 100 untuk masing-masing kata kunci. Untuk menyederhanakan masalah, saya hanya mengambil bagian wajah saja. Saya membuat script python yang menggunakan OpenCV untuk melakukan cropping pada wajah. Tentu saja kadang ada beberapa hasil crop yang bukan wajah Isyana maupun Raisa. Saya membuangnya secara manual.

Karena hasil dari flickr kurang variatif (gambar monoton dari sedikit event manggung yang sama), saya menambahkan beberapa gambar dari google image. Karena mulai malas, saya melakukan cropping secara manual.

Berikut cuplikan dataset yang saya kumpulkan:

Definisi model (model.py)

Model yang buat simpel saja. Hanya satu layer convolutional dengan aktivasi leaky ReLU diikuti dengan batch normalization dan pooling. Output akan di-flatten dan diteruskan ke 2 layer fully connected, masing-masing dengan aktivasi leaky ReLU dan sigmoid. Saya menggunakan bantuan pytorch lightning untuk memangkas penulisan kode training loop.

import pytorch_lightning as pl
import torch
from torchvision.transforms import transforms


class Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.cnn = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, kernel_size=3, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.BatchNorm2d(16),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(197136, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 1),
        )

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x).squeeze()
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.BCEWithLogitsLoss()(y_hat, y.float())
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.BCEWithLogitsLoss()(y_hat, y.float())
        return {"val_loss": loss}


transform_fn = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
    ]
)

Training

Berikut ini adalah kode training yang saya buat:

import pytorch_lightning as pl
from model import Classifier, transform_fn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

transform_fn = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
    ]
)

training_set = ImageFolder(root="training_set", transform=transform_fn)
training_data = DataLoader(training_set, batch_size=32, shuffle=True)

model = Classifier()
trainer = pl.Trainer(max_epochs=50, accelerator="mps")
trainer.fit(model, training_data)

Saya sediakan pula fungsi transformasi untuk melakukan resizing. Selain itu, saya mengikutkan fungsi rotasi, flipping, dan jittering citra secara random, dengan harapan model dapat melakukan generalisasi lebih baik. Semua transformasi dilakukan tiap kali data pada training_data diakses.

Evaluasi

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from matplotlib.colors import Normalize
from model import Classifier
from scipy.ndimage import gaussian_filter


def visualize_image_and_saliency_map(image_file, clf):
    image = PIL.Image.open(image_file)
    image = image.resize((224, 224))
    image = torch.tensor(np.array(image)).float() / 255
    image = image.permute(2, 0, 1)

    # Penting, agar gradien dari citra juga dihitung saat backpropagation
    image.requires_grad_(True)

    # Buat prediksi dan ambil skornya
    pred = torch.sigmoid(clf(image.unsqueeze(0)))

    # Baris ini untuk menghitung gradien dari citra (dan juga semua parameter network)
    pred.backward()

    # Peroleh gradien dari citra
    grad = image.grad.detach().numpy()

    # Ambil nilai maksimum dari nilai absolut semua channel warna,
    # normalisasi ke 0-1, dan lakukan blur via gaussian filter dengan sigma=2
    grad = np.max(np.abs(grad), axis=0)
    grad = Normalize(clip=True)(grad)
    grad = gaussian_filter(grad, sigma=2)

    label = "Raisa" if pred.item() > 0.5 else "Isyana"
    fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharey=True, constrained_layout=True)
    ax[0].imshow(image.detach().numpy().squeeze().transpose(1, 2, 0))
    ax[0].set_title("Citra asli")
    ax[1].imshow(image.detach().numpy().transpose(1, 2, 0))
    im = ax[1].imshow(grad, alpha=0.7)
    ax[1].set_title("Overlay saliency map")
    fig.colorbar(im, ax=ax[1], fraction=0.05, pad=0.04)
    plt.suptitle(f"Prediksi={label}")
    plt.show()


clf = Classifier.load_from_checkpoint(
    "path/to/chekpoint.ckpt"
)
clf.eval()

visualize_image_and_saliency_map("testing_set/i1.jpg", clf)
visualize_image_and_saliency_map("testing_set/i2.jpg", clf)
visualize_image_and_saliency_map("testing_set/i3.jpg", clf)
visualize_image_and_saliency_map("testing_set/i4.jpg", clf)
visualize_image_and_saliency_map("testing_set/r1.jpg", clf)
visualize_image_and_saliency_map("testing_set/r2.jpg", clf)
visualize_image_and_saliency_map("testing_set/r3.jpg", clf)

Hasil

Untuk sekarang, saya tidak terlalu tertarik dengan akurasi yang didapatkan (karena keterbatasan effort untuk memperoleh testing set). Untuk mengecek hasil saliency map, saya mengunduh beberapa citra tambahan dari google image. Berikut ini beberapa hasil yang saya peroleh:

1) Isyana

Dapat dilihat, terdapat satu kesalahan klasifikasi.

2) Raisa

Semuanya berhasil diklasifikasi dengan tepat.

Bias pada dataset

Dataset tidak luput dari bias. Apabila model dilatih secara sub-optimal, kadang-kadang ia akan memberikan saliency map yang “salah fokus”. Perhatikan contoh berikut ini:

Saliency map meng-highlight bagian dari kaca mata dengan cukup kuat. Kita bisa mengartikan bahwa, untuk menentukan dia Isyana atau Raisa, dapat dilihat dari apakah dia menggunakan kacamata atau tidak. Tentu saja ini bukan hal yang kita harapkan. Setelah ditelusuri lebih jauh, ternyata cukup banyak dari sampel wajah Isyana yang menggunakan kacamata. Ini mengakibatkan model mempelajari fokus yang kurang tepat.

Jika dataset terlalu bias terhadap ada/tidaknya kacamata, alih-alih membuat classifier Isyana v.s. Raisa, kita malah tak sengaja membuat classifier pakai kacamata v.s. tidak pakai kacamata. Salah satu solusinya adalah mencari dataset yang lebih banyak dan variatif, termasuk misalnya, mencari sampel Raisa yang menggunakan kacamata.

Kesimpulan

Visualisasi saliency map adalah cara sederhana namun efektif untuk mengetahui bagaimana model “melihat”. Cukup dengan menghitung gradien dari citra, kita dapat memperoleh saliency map yang meng-highlight bagian dari citra yang berkoresponensi dengan skor klasifikasi. Ini cukup bermanfaat dan bisa kita tambahkan ke kotak perkakas debugging kita.

Footnotes

  1. Wikipedia: Saliency_map (https://en.wikipedia.org/wiki/Saliency_map)↩︎

  2. Brahimi, Mohammed, et al. “Deep learning for plant diseases: detection and saliency map visualisation.” Human and machine learning. Springer, Cham, 2018. 93-117. (https://arxiv.org/abs/1312.6034)↩︎