Training a Convolutional Neural Network (CNN) in the Browser

Introduction: What is a Convolutional Neural Network?

Convolutional Neural Networks (CNNs) are a class of deep learning models specialized for processing grid-like data, such as images. Introduced in the late 1980s and popularized by LeNet-5 in 1998, they have become the standard for image classification, object detection, and other computer vision tasks.

The main advantage of CNNs is their ability to automatically and adaptively learn spatial hierarchies of features from input images.

How to Train a CNN

  1. Prepare the Dataset: Gather and preprocess images, often normalizing and resizing them.
  2. Design the Model: Stack convolutional, pooling, and dense layers to define the network architecture.
  3. Compile the Model: Specify a loss function, optimizer, and metrics for evaluation.
  4. Train: Feed the network batches of images and labels, allowing it to optimize its filters and weights through backpropagation.
  5. Evaluate & Improve: Test the model on unseen data and tune hyperparameters or architecture as needed.

Training minimizes the loss function through many small, iterative weight updates, so the model gradually learns the patterns in the data.
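
To make "minimizing the loss through iterative updates" concrete, here is a minimal sketch that fits a single weight by gradient descent. It is far simpler than a CNN, and the data, learning rate, and step count are illustrative assumptions, but a real network repeats the same idea across many weights.

```javascript
// Minimal sketch: fit y = w * x by gradient descent on mean squared error.
// The data, learning rate, and step count are illustrative assumptions.
function meanSquaredError(w, xs, ys) {
  let sum = 0;
  for (let i = 0; i < xs.length; i++) {
    const err = w * xs[i] - ys[i];
    sum += err * err;
  }
  return sum / xs.length;
}

function gradient(w, xs, ys) {
  // d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
  let sum = 0;
  for (let i = 0; i < xs.length; i++) {
    sum += 2 * (w * xs[i] - ys[i]) * xs[i];
  }
  return sum / xs.length;
}

function fit(xs, ys, learningRate = 0.05, steps = 200) {
  let w = 0; // start from an arbitrary initial weight
  for (let step = 0; step < steps; step++) {
    w -= learningRate * gradient(w, xs, ys); // move against the gradient
  }
  return w;
}

const xs = [1, 2, 3, 4];
const ys = [2, 4, 6, 8]; // generated by y = 2x
const w = fit(xs, ys);
console.log(w.toFixed(3), meanSquaredError(w, xs, ys));
```

Each step nudges the weight in the direction that lowers the loss, so the fitted weight ends up very close to 2.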

What Does a CNN Look Like?

A CNN is made of several layers, each containing many "neurons" (small computational units). The main types of layers are convolutional, pooling, and dense (fully connected) layers.

Input Image -> Convolution Layer -> Pooling Layer -> Fully Connected Layer (Output) -> Prediction

Diagram: Example flow in a simple CNN. Each layer transforms the image into a new set of features, ending in an output prediction.
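
One way to follow this flow is to track how the image size changes from layer to layer. The sketch below assumes no padding, a 5x5 filter, 2x2 pooling, and 8 filters. These values are chosen for illustration and are not taken from the tutorial's model.

```javascript
// Output width/height of a convolution or pooling layer with no padding.
function outputSize(inputSize, windowSize, stride) {
  return Math.floor((inputSize - windowSize) / stride) + 1;
}

const input = 28;                                  // 28x28 input image
const afterConv = outputSize(input, 5, 1);         // 5x5 filters, stride 1
const afterPool = outputSize(afterConv, 2, 2);     // 2x2 pooling, stride 2
const filters = 8;                                 // assumed number of filters
const flattened = afterPool * afterPool * filters; // values fed to the dense layer
console.log(afterConv, afterPool, flattened);      // 24 12 1152
```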

Easy Example: Recognizing Handwritten Digits

Imagine you want your computer to recognize whether a picture shows the digit “3” or “7”. You show it thousands of examples of hand-drawn numbers.

The network learns by adjusting its internal weights each time it makes a mistake, until it becomes very good at telling 3s and 7s apart, even on handwriting it has never seen before!

TensorFlow.js: CNNs in JavaScript

TensorFlow.js is a JavaScript library for training and running machine learning models in the browser and on Node.js.

With TensorFlow.js, you can experiment and visualize neural networks directly in the browser.

Libraries for Building, Training, and Visualizing Neural Networks

Several JavaScript libraries make it possible to define, train, and visualize neural network models directly in your web browser.

These tools allow you to build, train, and monitor neural networks interactively and visually on the client side.

About the MNIST Dataset and Labels

The MNIST dataset consists of grayscale images of handwritten digits (0–9) and is commonly used for training and testing image classification models. The images are stored together in a single PNG file, where each digit occupies a 28x28 pixel square. Each image has an associated label, stored in a separate file, which indicates the digit it represents.

Together, these files provide the input data (images) and the ground-truth answers (labels) for training and evaluating neural network models.
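
Before images and labels reach the network, they are usually converted to numbers in a convenient range. The sketch below shows two common preparation steps, scaling pixels to 0..1 and one-hot encoding labels. Both are assumptions about how the data might be prepared, not a description of this page's loader, and the function names are hypothetical.

```javascript
// Scale 8-bit pixel values (0..255) into the range 0..1.
function normalizePixels(pixels) {
  return Float32Array.from(pixels, (p) => p / 255);
}

// Turn a digit label into a one-hot vector of length numClasses.
function oneHot(label, numClasses = 10) {
  const vec = new Array(numClasses).fill(0);
  vec[label] = 1;
  return vec;
}

const rawPixels = new Uint8Array([0, 51, 255]);
const scaled = normalizePixels(rawPixels); // values near 0, 0.2, and 1
const encoded = oneHot(3);                 // 1 at index 3, 0 elsewhere
console.log(scaled, encoded);
```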

Visualizing the Cleaned MNIST Dataset

Once the MNIST dataset is loaded, it can be helpful to visualize some sample images with their labels. This allows you to verify that the data was loaded correctly and to better understand the nature of the data your neural network will process.
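
A quick sanity check does not need a canvas. As a rough sketch, the helper below (hypothetical, not part of this page) prints an image as text characters so you can eyeball it in the console next to its label. It assumes pixel values already scaled to 0..1 and uses a tiny 4x4 stand-in instead of a real 28x28 digit.

```javascript
// Render a flat array of grayscale pixels (0..1) as rows of text.
function toAscii(pixels, width) {
  const shades = " .:*#"; // dark to bright
  const rows = [];
  for (let start = 0; start < pixels.length; start += width) {
    let row = "";
    for (let i = start; i < start + width; i++) {
      const level = Math.min(shades.length - 1, Math.floor(pixels[i] * shades.length));
      row += shades[level];
    }
    rows.push(row);
  }
  return rows.join("\n");
}

// A hypothetical 4x4 "image" standing in for a real 28x28 digit.
const tiny = [
  0, 1, 1, 0,
  0, 0, 1, 0,
  0, 0, 1, 0,
  0, 1, 1, 1,
];
console.log("label: 1\n" + toAscii(tiny, 4));
```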

Splitting the Dataset: Training and Testing Sets

To evaluate machine learning models fairly, we divide the dataset into two parts: training and testing. The training set is used to teach the model, while the testing set is reserved for checking performance on unseen data. A typical split is 80% training and 20% testing.
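
A split like this can be sketched in a few lines: shuffle the examples, then cut the list at the chosen fraction. The function name and the example objects below are hypothetical, and shuffling before the cut is an assumption about good practice rather than something this page specifies.

```javascript
// Shuffle a copy of the examples (Fisher-Yates), then cut it at trainFraction.
function splitDataset(examples, trainFraction = 0.8, random = Math.random) {
  const shuffled = examples.slice();
  for (let i = shuffled.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
  }
  const cut = Math.floor(shuffled.length * trainFraction);
  return { trainSet: shuffled.slice(0, cut), testSet: shuffled.slice(cut) };
}

// Hypothetical examples: each pairs an image id with its digit label.
const examples = Array.from({ length: 10 }, (_, i) => ({ imageId: i, label: i % 10 }));
const { trainSet, testSet } = splitDataset(examples);
console.log(trainSet.length, testSet.length); // 8 2
```

Shuffling first keeps both sets representative, since datasets are often stored in an order that groups similar examples together.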

Defining a Convolutional Neural Network (CNN)

Convolutional Neural Networks (CNNs) are highly effective for image recognition tasks like MNIST digit classification. Let's define a simple CNN model using TensorFlow.js. Click the button below to create the model and see its layer structure.
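
As a rough idea of what such a definition might look like in code, here is a sketch using the TensorFlow.js Layers API. The function name createModel, the filter counts, kernel sizes, and the Adam optimizer are assumptions for illustration and are not necessarily what this page's button runs. The tf argument is the TensorFlow.js namespace, for example the global created by the library's script tag.

```javascript
// Sketch of a small CNN for 28x28 grayscale digits (all sizes are assumptions).
// `tf` is the TensorFlow.js namespace, passed in so the function has no
// hidden dependency on how the library was loaded.
function createModel(tf) {
  const model = tf.sequential();

  // First block: learn 8 filters of size 5x5, then halve the width and height.
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1], // height, width, one grayscale channel
    kernelSize: 5,
    filters: 8,
    activation: "relu",
  }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));

  // Second block: more filters on the smaller feature maps.
  model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: "relu" }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));

  // Flatten the feature maps and map them to 10 digit probabilities.
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

  model.compile({
    optimizer: tf.train.adam(),
    loss: "categoricalCrossentropy", // expects one-hot labels
    metrics: ["accuracy"],
  });
  return model;
}
```

With the library loaded, calling createModel(tf).summary() prints the layer structure and the number of parameters in each layer.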

What Is a Convolutional Neural Network?

Figure: CNN architecture illustration.

A CNN processes images step-by-step:

  1. Convolutional Layers: Detect features (like edges and shapes) by sliding filters over the image.
  2. Pooling Layers: Downsample the feature maps while keeping the strongest responses, making the network faster and more robust to small shifts in the image.
  3. Flatten & Dense Layers: Combine detected features to predict the digit class (0–9).

CNNs are the backbone of modern image recognition systems, automatically learning which features matter most for the task.
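
The first two steps can be sketched in plain JavaScript, without any library. The example below slides a hand-picked 2x2 edge filter over a small image and then applies 2x2 max pooling. The filter values and the 5x5 image are assumptions for illustration, since a trained CNN learns its own filter values.

```javascript
// Slide a square filter over a 2D image (no padding, stride 1).
// Like most deep learning libraries, this does not flip the filter.
function convolve2d(image, kernel) {
  const k = kernel.length;
  const out = [];
  for (let y = 0; y + k <= image.length; y++) {
    const row = [];
    for (let x = 0; x + k <= image[0].length; x++) {
      let sum = 0;
      for (let ky = 0; ky < k; ky++) {
        for (let kx = 0; kx < k; kx++) {
          sum += image[y + ky][x + kx] * kernel[ky][kx];
        }
      }
      row.push(sum);
    }
    out.push(row);
  }
  return out;
}

// Keep the largest value in each non-overlapping size x size window.
function maxPool2d(featureMap, size = 2) {
  const out = [];
  for (let y = 0; y + size <= featureMap.length; y += size) {
    const row = [];
    for (let x = 0; x + size <= featureMap[0].length; x += size) {
      let best = -Infinity;
      for (let dy = 0; dy < size; dy++) {
        for (let dx = 0; dx < size; dx++) {
          best = Math.max(best, featureMap[y + dy][x + dx]);
        }
      }
      row.push(best);
    }
    out.push(row);
  }
  return out;
}

// 5x5 image with a vertical edge: dark (0) on the left, bright (1) on the right.
const image = Array.from({ length: 5 }, () => [0, 0, 1, 1, 1]);
// Hand-picked filter that responds to dark-to-bright vertical edges.
const edgeFilter = [
  [-1, 1],
  [-1, 1],
];
const features = convolve2d(image, edgeFilter); // 4x4 map, each row is [0, 2, 0, 0]
const pooled = maxPool2d(features);             // [[2, 0], [2, 0]]
console.log(features, pooled);
```

The feature map is non-zero only where the edge sits, and pooling halves its size while keeping that strong response.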

Image: Wikipedia, Convolutional neural network

Training the CNN Model

Now that we have split our dataset and defined a CNN model, it's time to train the model using the training set. Training means letting the model learn the relationship between input images and their correct labels by adjusting its internal parameters.

  1. Click the Train Model button below to begin training.
  2. The process will run for several epochs (full passes over the training data), showing accuracy and loss for each epoch.
  3. A chart below will update to show your model's progress visually.
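
Behind a button like this, the training step could look roughly like the sketch below, which uses the fit method of a TensorFlow.js model. The function name trainModel, the epoch count, the batch size, and the validation split are assumptions for illustration. It also assumes a model compiled with metrics: ["accuracy"], which TensorFlow.js reports under the key acc, and training data already converted to tensors.

```javascript
// Sketch: train for a few epochs and report loss/accuracy after each one.
// `model` is an already compiled TensorFlow.js model; `trainXs` and `trainYs`
// are tensors holding the training images and their one-hot labels.
async function trainModel(model, trainXs, trainYs, onProgress = console.log) {
  const history = await model.fit(trainXs, trainYs, {
    epochs: 5,            // illustrative value
    batchSize: 64,        // illustrative value
    validationSplit: 0.1, // hold out 10% of the training data for validation
    shuffle: true,
    callbacks: {
      // Called once per full pass over the training data.
      onEpochEnd: (epoch, logs) => {
        onProgress(
          `epoch ${epoch + 1}: loss=${logs.loss.toFixed(4)} acc=${logs.acc.toFixed(4)}`
        );
      },
    },
  });
  return history;
}
```

The onProgress callback is where a page would update its chart; here it defaults to logging each epoch's numbers to the console.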

Testing (Evaluating) the CNN Model

After training, we need to test our model on unseen data to measure how well it generalizes. This is called model evaluation.
We use the testing set (not used in training) and calculate metrics such as accuracy and loss.

  1. Click the Test Model button below to evaluate the CNN on the test dataset.
  2. The chart will show loss and accuracy for each test batch.
  3. Final results will be displayed once testing is complete.
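
The two metrics can be computed by hand once the model has produced a probability for every class. The sketch below uses made-up outputs for three images and three classes, and the function names are hypothetical, but the arithmetic is the same for ten digit classes.

```javascript
// Index of the largest value, i.e. the predicted class for one image.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Fraction of images whose predicted class matches the true label.
function accuracy(probabilities, labels) {
  let correct = 0;
  for (let i = 0; i < labels.length; i++) {
    if (argMax(probabilities[i]) === labels[i]) correct++;
  }
  return correct / labels.length;
}

// Average cross-entropy loss: low when the true class gets high probability.
function crossEntropy(probabilities, labels) {
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    total += -Math.log(probabilities[i][labels[i]]);
  }
  return total / labels.length;
}

// Hypothetical model outputs for three test images over classes 0..2.
const probabilities = [
  [0.7, 0.2, 0.1],
  [0.1, 0.8, 0.1],
  [0.3, 0.4, 0.3],
];
const labels = [0, 1, 2];
console.log(accuracy(probabilities, labels));     // 2 of 3 predictions are correct
console.log(crossEntropy(probabilities, labels)); // average loss over the 3 images
```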

Understanding the Confusion Matrix

A confusion matrix is a table used to evaluate the performance of a classification model. Each row of the matrix represents the instances of an actual class, while each column represents the instances of a predicted class.

The confusion matrix provides deeper insight into how your model performs on each class, beyond just overall accuracy.
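
Here is a minimal sketch of how such a matrix can be built from lists of actual and predicted labels. The labels are made up for a two-class "3 vs 7" problem, and the function names are hypothetical. Correct predictions land on the diagonal, and everything off the diagonal is a mistake.

```javascript
// Build a confusion matrix: rows are actual classes, columns are predicted classes.
function confusionMatrix(actual, predicted, numClasses) {
  const matrix = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  for (let i = 0; i < actual.length; i++) {
    matrix[actual[i]][predicted[i]] += 1;
  }
  return matrix;
}

// Per-class accuracy: the diagonal entry divided by the row total.
function perClassAccuracy(matrix) {
  return matrix.map((row, cls) => {
    const total = row.reduce((a, b) => a + b, 0);
    return total === 0 ? 0 : row[cls] / total;
  });
}

// Hypothetical labels: class 0 stands for "3", class 1 stands for "7".
const actual = [0, 0, 0, 1, 1, 1];
const predicted = [0, 0, 1, 1, 1, 0];
const matrix = confusionMatrix(actual, predicted, 2);
console.log(matrix);                   // [[2, 1], [1, 2]]
console.log(perClassAccuracy(matrix)); // two thirds correct for each class
```

In this example, one "3" was mistaken for a "7" (row 0, column 1) and one "7" was mistaken for a "3" (row 1, column 0).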