This page is part of the documentation for the Machine Learning Database.

It is a static snapshot of a Notebook which you can play with interactively by trying MLDB online now.
It's free and takes 30 seconds to get going.

MLPaint: Real-Time Handwritten Digits Recognizer

The automatic recognition of handwritten digits is now a well-understood and well-studied Machine Vision and Machine Learning problem. We will be using MNIST (check out Wikipedia's page on MNIST) to train our models. From the description on Yann LeCun's page for the MNIST database of handwritten digits:

The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples, and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting.

To learn more, you can also check out Kaggle's Digit Recognizer page.

In this demo, we will use MLDB's classifier functions and REST API to create a plugin that predicts the value of handwritten digits. We will also use MLDB's explain functions to visually represent the predictive "value or importance" of each pixel in our final predictions.

Check out the video below for a demo of what we'll be creating:

In [14]:
from IPython.display import YouTubeVideo

The notebook cells below use pymldb's Connection class to make REST API calls. You can check out the Using pymldb Tutorial for more details.

In [2]:
from pymldb import Connection
mldb = Connection()

... and other Python libraries:

In [3]:
import random
import math
import numpy as np
from pandas import DataFrame
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

from IPython.display import display
from ipywidgets import widgets

Loading the data

A pickled version of the dataset is available on the website.

The dataset has been unpickled and saved in public Amazon S3 cloud storage. Check out MLDB's Protocol Handlers for Files and URLs.

In [ ]:
data_url_mnist = ''

print mldb.put('/v1/procedures/import_digits_mnist', {
    "type": "import.text",
    "params": {
        "dataFileUrl": data_url_mnist,
        "outputDataset": "digits_mnist",
        # keep every column except "785" as features, and expose "785" as the label
        "select": "{* EXCLUDING(\"785\")} AS *, \"785\" AS label",
        "runOnCreation": True,
    }
})

In the original MNIST datasets, the features and labels were in two separate datasets. To make things easier, we joined the features with the labels. Column 785 holds the labels, so the import renames it to 'label'.

Let's explore the data

See the Query API documentation for more details on SQL queries.

In [5]:
data_stats = mldb.query("""
    SELECT
        avg(horizontal_count({* EXCLUDING(label)})) as NoOfFeatures,
        count(label) AS TestExamples
    FROM digits_mnist 
""")
print data_stats
          NoOfFeatures  TestExamples
[]                 784         60000

Each example is a row made up of 784 pixels or features. By reshaping the 1D data into a 2D representation, we can visualize the data a little better. At each refresh, we get randomly selected rows for each label using the sample function in a SQL From Expression.
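
As a minimal sketch of the reshaping step described above, the snippet below turns a flat row of 784 values into a 28 x 28 grid using plain Python. The helper name `reshape_to_square` and the dummy pixel values are assumptions made for illustration; the notebook itself relies on NumPy's `reshape` for the same job.

```python
import math

def reshape_to_square(flat_pixels):
    """Turn a flat list of n*n pixel values into n rows of n values."""
    side = int(math.sqrt(len(flat_pixels)))
    if side * side != len(flat_pixels):
        raise ValueError("pixel count is not a perfect square")
    return [flat_pixels[r * side:(r + 1) * side] for r in range(side)]

# Hypothetical 784-pixel row: each value is simply its own index.
flat = list(range(784))
image = reshape_to_square(flat)

print("rows: %d, columns: %d" % (len(image), len(image[0])))
```

Pixel `i` of the flat row ends up at row `i // 28`, column `i % 28`, which is the row-major layout that `reshape(sq, sq)` also produces.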

In [ ]:
labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

x_array1 = []
x_array2 = []
x_data = []
sq = int(np.sqrt(data_stats['NoOfFeatures'][0]))

for label in labels:

    x_data = mldb.query("""
        SELECT * EXCLUDING(label) 
        FROM sample(
                (select * FROM digits_mnist WHERE label = %d AND rowHash() %% 5 = 0),
                {rows: 1}
        )
    """ %label)
    # digits 0-4 go in the left panel, digits 5-9 in the right panel
    if label < 5:
        x_array1.extend(x_data.values.reshape(sq, sq))
    else:
        x_array2.extend(x_data.values.reshape(sq, sq))

f, (fig1, fig2) = plt.subplots(1, 2, sharey=True)

Training a classifier

We will create a Procedure of type classifier.experiment to train and test our model. The configuration parameter defines a Random Forest-style ensemble: bagging over boosted decision trees, with random feature selection (random_feature_propn). The model may take some time to train...

In [7]:
conf_algo = {
    "bbdt_d5": {
        "type": "bagging",
        "verbosity": 3,
        "weak_learner": {
            "type": "boosting",
            "verbosity": 3,
            "weak_learner": {
                "type": "decision_tree",
                "max_depth": 7,
                "verbosity": 0,
                "update_alg": "gentle",
                "random_feature_propn": 0.3
            },
            "min_iter": 5,
            "max_iter": 30
        },
        "num_bags": 5
    }
}

conf_class = {
    "type": "classifier.experiment",
    "params": {
        "experimentName": "mnist_model",
        "mode": "categorical",
        "inputData" : """
            SELECT
                {* EXCLUDING(label*)} AS features, 
                label AS label
            FROM digits_mnist
        """,
        "datasetFolds": [
            {
                "trainingWhere": "rowHash() % 5 != 0", #80% of total data
                "testingWhere": "rowHash() % 5 = 0"    #20% of total data
            }
        ],
        "algorithm": "bbdt_d5",
        "configuration": conf_algo,
        "modelFileUrlPattern": "file://models/mnist_model.cls",
        "keepArtifacts": True,
        "outputAccuracyDataset": True,   
        "runOnCreation": True,
        "evalTrain": True
    }
}

results = mldb.put("/v1/procedures/mnist_model", conf_class)

In [8]:
accuracy = results.json()['status']['firstRun']['status']['aggregatedTest']['weightedStatistics']['accuracy']['mean']
print "\nModel classification accuracy on test set = %0.4f\n" % accuracy
Model classification accuracy on test set = 0.9944
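
The trainingWhere and testingWhere clauses above split the rows by taking a row hash modulo 5. The sketch below illustrates the same idea outside MLDB, under stated assumptions: `fold_of` is a hypothetical helper, the row names are made up, and MD5 stands in for MLDB's rowHash(), whose actual hash function is not described here.

```python
import hashlib

def fold_of(row_name, n_folds=5):
    """Stable bucket in [0, n_folds) derived from the row name.

    MD5 is only a stand-in for MLDB's rowHash(); any stable hash would do.
    """
    digest = hashlib.md5(row_name.encode("utf-8")).hexdigest()
    return int(digest, 16) % n_folds

# Hypothetical row names, not the real dataset's.
row_names = ["row%d" % i for i in range(10000)]

# Bucket 0 is held out for testing; buckets 1-4 are used for training.
train = [name for name in row_names if fold_of(name) != 0]
test = [name for name in row_names if fold_of(name) == 0]

print("test share: %.1f%%" % (100.0 * len(test) / len(row_names)))
```

Because the bucket depends only on the row name, the same row always lands on the same side of the split, so later queries can select test rows again with the same modulo condition.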

We now build the confusion matrix from the test-set results by pivoting them into a predicted-by-actual table with pandas' pivot_table. You can learn more about confusion matrices here.

In [9]:
confusionMatrix = results.json()['status']['firstRun']['status']['folds'][0]['resultsTest']['confusionMatrix']
confusionMatrix = pd.DataFrame(confusionMatrix).pivot_table(index="predicted", columns="actual")

df = np.log(confusionMatrix)
df = df.fillna(0)

fig = plt.figure(figsize=(8, 8))
plt.imshow(df, interpolation='nearest')
plt.yticks(np.arange(0, 10, 1), fontsize=14)
plt.ylabel("Predicted", fontsize=16)
plt.xticks(np.arange(0, 10, 1), fontsize=14)
plt.xlabel("Actual", fontsize=16)

The confusion matrix above suggests that the model does a good job at classification. In a small percentage of cases, the model mistakes a '4' for a '9' and a '3' for a '2'. The digits in each pair have similar concentrations of pixels, so these confusions make some sense.
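
To make the pivot step concrete, here is a small sketch that builds a predicted-by-actual table from long-format records in plain Python. The field name `count`, the helper `pivot_confusion`, and the three-class numbers are assumptions for illustration only; they are not results from the MNIST model.

```python
def pivot_confusion(records, labels):
    """Build matrix[predicted][actual] = count from long-format records."""
    matrix = {p: {a: 0 for a in labels} for p in labels}
    for rec in records:
        matrix[rec["predicted"]][rec["actual"]] += rec["count"]
    return matrix

# Hypothetical counts for a 3-class problem (not MNIST results).
labels = [0, 1, 2]
records = [
    {"actual": 0, "predicted": 0, "count": 50},
    {"actual": 1, "predicted": 1, "count": 45},
    {"actual": 1, "predicted": 2, "count": 5},
    {"actual": 2, "predicted": 2, "count": 48},
    {"actual": 2, "predicted": 1, "count": 2},
]

matrix = pivot_confusion(records, labels)

# Correct predictions sit on the diagonal (predicted == actual).
correct = sum(matrix[l][l] for l in labels)
total = sum(rec["count"] for rec in records)
accuracy = float(correct) / total
```

Off-diagonal cells are the confusions: in this toy table, `matrix[2][1]` counts how often an actual '1' was predicted as '2'.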

How does the model make its predictions?

The 'explain' function provides each pixel's weight on the final outcome.

Let's create a function of type classifier.explain to help us understand what's happening here.

In [10]:
print mldb.put('/v1/functions/mnist_explainer', { 
    "id": "mnist_explainer", 
    "type": "classifier.explain",
    "params": { "modelFileUrl": "file://models/mnist_model.cls" }
})
<Response [201]>

We automatically get a REST API to test our model with individual digits

The procedure above created a Function of type classifier for us. We will be using two functions:

  • The scorer function: Scores aren't probabilities, but they can be used to create binary classifiers by applying a cutoff threshold. MLDB's classifier.experiment procedure that we have seen previously outputs a score for each digit (even the wrong ones).
  • The explain function: The explain function shows how each pixel of an image, and its value (black or white), contributes to the model's prediction. We colored such pixels in green for positive contributions and red for negative contributions in Figure 2 below. In essence, pixels flagged in red in the explain figure should be changed to get a better score. For example, a white-colored pixel that was seen frequently for digit '5' in the training set will be flagged green if it is seen for digit '5' in the test set. If the same pixel is actually of a different color for digit '5' in the test set, then the pixel will be flagged red.

Note that the digits are from the test set: we used 80% of the data for training and 20% for testing. You can also get the same digit written differently by using the offset bar. We are using a SQL OFFSET, called 'other_example' in the code below. The offset specifies the number of rows to skip before returning values from the query expression.
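
The two ways of reading the scores described above can be sketched as follows: taking the highest score to pick a digit, and applying a cutoff to one digit's score to get a yes/no answer. The helper names, the cutoff of 0.0, and the score values are all hypothetical; they are not output from the trained model.

```python
def predict_digit(scores):
    """Return the index of the highest score, i.e. the predicted digit."""
    return max(range(len(scores)), key=lambda i: scores[i])

def is_digit(scores, digit, cutoff=0.0):
    """Binary decision for one digit: is its score above the cutoff?"""
    return scores[digit] > cutoff

# Hypothetical scores for digits 0..9 (not real model output).
scores = [-0.8, -0.6, -0.9, -0.2, -0.7, 0.9, -0.5, -0.9, -0.1, -0.6]

predicted = predict_digit(scores)
print("predicted digit: %d" % predicted)
print("looks like a 5? %s" % is_digit(scores, 5))
print("looks like a 3? %s" % is_digit(scores, 3))
```

The cutoff is a design choice: raising it makes the binary classifier for a given digit more conservative, at the cost of rejecting more genuine examples of that digit.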

In [11]:
def rgb_explain(x):
    # Map each pixel's explanation weight to an [R, G, B] triple:
    # green for positive contributions, red for negative ones.
    scale = 5
    explain = [[0, 0, 0]] * int(data_stats['NoOfFeatures'][0]) # [R,G,B] for color image
    for j, col in enumerate(x.columns.values):
        index = int(col)
        val = x.values[0][j] * scale
        if (val >= 0.2):
            explain[index] = [0, val, 0] # make it green
        elif (val <= -0.2):
            explain[index] = [- val, 0, 0] # make it red
    return np.array(explain).reshape(sq, sq, 3)

def test_img_plot(digit = [0,9], other_example=[0,1000]):
    data = mldb.query("""
        SELECT
            * EXCLUDING(label),
            mnist_model_scorer_0({ features: {* EXCLUDING(label*)} })[scores] AS score
        FROM digits_mnist
        WHERE label = %(digit)d AND rowHash() %% 5 = 0
        LIMIT 1
        OFFSET %(offset)d
    """ % {"digit": digit, "offset": other_example})
    data_array = data.values
    # the last 10 columns are the per-digit scores; the rest are pixels
    rand_test_img = data_array[0][:-10].reshape(sq, sq)
    scores = data_array[0][-10:]
    explain_data = mldb.query("""
        SELECT
            mnist_explainer({
                label: %(digit)d,
                features: {* EXCLUDING(label)}
            })[explanation] AS *
        FROM digits_mnist 
        WHERE label = %(digit)d AND rowHash() %% 5 = 0
        LIMIT 1
        OFFSET %(offset)d
    """ % {"digit": digit, "offset": other_example})
    explain_img = rgb_explain(explain_data)
    fig = plt.figure(figsize=(8, 8))
    # plot digit image
    ax1 = plt.subplot2grid((4, 4), (0, 0), colspan=2, rowspan = 2)
    ax1.imshow(rand_test_img)
    ax1.set_title("Fig1: You chose the digit below", fontsize=12, fontweight='bold')
    # plot explain matrix
    ax2 = plt.subplot2grid((4, 4), (0, 2), colspan=2, rowspan = 2)
    ax2.imshow(explain_img)
    ax2.set_title("Fig2: Explain matrix picture of digit %d" %digit, fontsize=12, fontweight='bold')
    # plot scores
    ax3 = plt.subplot2grid((4, 4), (2, 0), colspan=4, rowspan = 2)
    greater_than_zero = scores >= 0
    lesser_than_zero = scores < 0
    ax3.barh(np.arange(len(scores))[greater_than_zero]-0.5, scores[greater_than_zero], color='#87CEFA', height=1)
    ax3.barh(np.arange(len(scores))[lesser_than_zero]-0.5, scores[lesser_than_zero], color='#E6E6FA', height=1)
    ax3.yaxis.set_ticks(np.arange(0, 10, 1))
    ax3.set_title("Fig3: Scores for each number - the number with the highest score wins!", fontsize=12, fontweight='bold')

In the representation of the explain matrix (Figure 2), the green pixels help increase the score of the chosen digit, while the red pixels help decrease it. The explain matrix tells us which pixels the model deems most important. For example, if nothing was drawn in the top left corner of the picture during training, the explain matrix will provide no information on the top left set of pixels. More generally, if a pixel is not part of the classification rules learned during training (i.e. not on any leaf), that pixel will not show up in the explain matrix.

Making a simple web app using MLDB plug-in functionality

We've built a fun web app called MLPaint that uses everything we've shown here to do real-time recognition of digits. The app is shown in the YouTube video at the top of the notebook.

The app is built with an MLDB plugin; plugins allow us to extend functionality that we have seen so far. For more information, check out the documentation on plugins.

By running the cell below, the plugin is checked out from GitHub and loaded into MLDB:

In [13]:
mldb.put("/v1/plugins/mlpaint", {
    "type": "python",
    "params": {
        "address": "git://"
    }
})
PUT http://localhost/v1/plugins/mlpaint
201 Created
{
  "config": {
    "params": {
      "address": "git://"
    }, 
    "type": "python", 
    "id": "mlpaint"
  }, 
  "state": "ok", 
  "type": "python", 
  "id": "mlpaint"
}

Now you can browse to the plugin UI.

NOTE: this only works if you're running this Notebook live, not if you're looking at a static copy. You can run this demo by trying MLDB for free.

Where to next?

Check out the Image Processing with Convolutions demo if you are interested in building your own custom image manipulation algorithms.

Otherwise, check out the other Tutorials and Demos.
