My Computer Can Read! Random Forest Classification in R

Marcus Codrescu
Published in Dev Genius
5 min read · Jun 14, 2023


Photo by Volkan Olmez on Unsplash

Introduction

Let’s teach a computer how to read handwritten digits with a random forest classifier in R.

We will use the MNIST dataset of handwritten digits to train and test our model. You can download the dataset yourself from this GitHub page: MNIST files in PNG format

MNIST Dataset Sample (Source: researchgate.net)

Prepare The Data

The trick to making this work is to encode the images in a format that our computer can understand. Every image is a matrix of pixel values (for MNIST, one grayscale intensity per pixel rather than separate RGB channels), so we can use the png::readPNG() function from the {png} package to read these values into an R matrix.

image <- png::readPNG("train/0/16585.png")

dim(image)

#> [1] 28 28

Each of the images in the MNIST dataset is 28 by 28 pixels. We can flatten this out across a single dimension to create a dataset with 784 features.

# Flatten a Matrix
image_flat <- as.vector(
  png::readPNG("train/0/16585.png")
)

length(image_flat)

#> [1] 784
Flatten a Matrix (Source: enriquegit.github.io)
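One detail worth knowing: as.vector() flattens a matrix in column-major order, so the first 28 features come from the first column of the image, not the first row. Here is a minimal sketch on a made-up 2 × 3 matrix (the values are hypothetical, not MNIST pixels):

```r
# A tiny stand-in "image": 2 rows, 3 columns (hypothetical values)
m <- matrix(1:6, nrow = 2, byrow = TRUE)

# as.vector() walks down each column in turn (column-major order)
flat <- as.vector(m)
flat
#> [1] 1 4 2 5 3 6

# The original shape can be recovered as long as the dimensions are known
restored <- matrix(flat, nrow = 2, ncol = 3)
identical(m, restored)
#> [1] TRUE
```

The order itself should not matter to the random forest, provided every image is flattened the same way.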

If the images were much larger, we could resize them first with the imager::resize() function from the {imager} package to avoid running out of memory.
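The idea behind downscaling can be sketched in base R. The helper below, downsample_2x(), is a hypothetical illustration that averages each 2 × 2 block of pixels. It is not how imager::resize() is implemented, and the input is random noise standing in for an image.

```r
# Hypothetical helper: halve each dimension by averaging 2 x 2 pixel blocks
downsample_2x <- function(img) {
  stopifnot(nrow(img) %% 2 == 0, ncol(img) %% 2 == 0)
  rows <- seq(1, nrow(img), by = 2)  # top row of each block
  cols <- seq(1, ncol(img), by = 2)  # left column of each block
  (img[rows, cols, drop = FALSE] +
    img[rows + 1, cols, drop = FALSE] +
    img[rows, cols + 1, drop = FALSE] +
    img[rows + 1, cols + 1, drop = FALSE]) / 4
}

set.seed(42)
img <- matrix(runif(28 * 28), nrow = 28, ncol = 28)  # random stand-in image
small <- downsample_2x(img)

dim(small)             # half the height and width of the input
length(as.vector(small))  # a quarter of the original feature count
```

Halving each side cuts the number of features per image to a quarter, which is where the memory saving comes from.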

Jumping Right In

One of the easiest ways to get started is to create a new RStudio project from Version Control by cloning the GitHub repository linked above.

After creating our project, our directory should look like this.

list.files()

#> [1] "download-and-convert-to-png.py"
#> [2] "LICENSE"
#> [3] "README.md"
#> [4] "test"
#> [5] "test.csv"
#> [6] "train"
#> [7] "train.csv"

Two files in the directory, train.csv and test.csv, hold the file path and label for each image in the respective set. We can read them into our session using readr::read_csv().

# Read the labels
training_images <- readr::read_csv(
  "train.csv",
  col_types = "cf"
)

testing_images <- readr::read_csv(
  "test.csv",
  col_types = "cf"
)

head(training_images)

#> # A tibble: 6 × 2
#> filepath label
#> <chr> <fct>
#> 1 train/0/16585.png 0
#> 2 train/0/24537.png 0
#> 3 train/0/25629.png 0
#> 4 train/0/20751.png 0
#> 5 train/0/34730.png 0
#> 6 train/0/15926.png 0

Notice that I specified the column types of the CSV files using the col_types = "cf" (character and factor) parameter because the {parsnip} package requires the classification label to be a factor.
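To see what the character-and-factor typing produces, here is a small base R sketch. It uses read.csv() with colClasses in place of readr::read_csv(), and the three rows of inline CSV text are made up for illustration.

```r
# Made-up CSV content in the same two-column layout as train.csv
csv_text <- "filepath,label
train/0/1.png,0
train/3/2.png,3
train/0/3.png,0"

# colClasses plays the role of col_types = "cf":
# first column character, second column factor
labels <- read.csv(
  text = csv_text,
  colClasses = c("character", "factor")
)

str(labels)
levels(labels$label)  # the distinct classes the model would learn
```

If the label were left as a number instead, a modelling function could treat the task as regression rather than classification.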

Creating The Training Dataset

The next step is to create the training dataset. We can create an empty matrix the size of the dataset and loop through each of the image files to read them in as a matrix and flatten to a vector.

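# Prepare Training Data
n <- length(training_images$filepath)

# Pre-allocate one row per image and one column per pixel
training_matrix <- matrix(
  nrow = n,
  ncol = 28 * 28
)

# Read each image and store its flattened pixels in row i
for (i in seq_len(n)) {
  training_matrix[i, ] <-
    as.vector(
      png::readPNG(
        training_images$filepath[i]
      )
    )
}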
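The same matrix can also be built without an explicit loop. The sketch below uses vapply(), which additionally checks that every image yields exactly 784 values. fake_read_png() is a hypothetical stand-in for png::readPNG() that returns random noise, so the example runs without the image files.

```r
# Hypothetical stand-in for png::readPNG(): returns a random 28 x 28 matrix
fake_read_png <- function(path) {
  matrix(runif(28 * 28), nrow = 28, ncol = 28)
}

paths <- c("a.png", "b.png", "c.png")  # hypothetical file paths

set.seed(1)
# vapply() returns one column per image and errors if any
# result is not a numeric vector of length 784
flat_columns <- vapply(
  paths,
  function(p) as.vector(fake_read_png(p)),
  FUN.VALUE = numeric(28 * 28),
  USE.NAMES = FALSE
)

# Transpose so that each row is one image, as in the loop version
image_matrix <- t(flat_columns)
dim(image_matrix)
```

With real data, swapping fake_read_png() for png::readPNG() and paths for training_images$filepath should give the same layout as the loop.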

To ensure the training data works with the model fitting function we convert it to a data frame and bind columns with the labels.

training_data <- cbind(
  dplyr::select(training_images, label),
  as.data.frame(
    training_matrix
  )
)

Fitting the Model

Now we can fit the model to the training data. We can also save the result as a .rds file, so we don’t have to repeat the training process every time we want to make predictions. It takes my computer about 6 minutes to fit the model.

# Model Fitting
# rand_forest() uses the {ranger} engine by default
rf_fit <- parsnip::fit(
  parsnip::rand_forest(
    mode = "classification"
  ),
  data = training_data,
  formula = label ~ .
)

# Save model
saveRDS(rf_fit, "mnist_model_fit.rds")

# rf_fit <- readRDS("mnist_model_fit.rds")
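The save-then-reload idea can be wrapped in a small helper. fit_or_load() below is a hypothetical function, and lm() on the built-in mtcars data stands in for the slow random forest fit so the sketch runs instantly.

```r
# Hypothetical helper: reuse a saved model if one exists, otherwise fit and save
fit_or_load <- function(path, fit_fun) {
  if (file.exists(path)) {
    return(readRDS(path))
  }
  fit <- fit_fun()
  saveRDS(fit, path)
  fit
}

model_path <- file.path(tempdir(), "example_fit.rds")
unlink(model_path)  # start from a clean state for the demonstration

# First call fits the model and writes the .rds file
first <- fit_or_load(model_path, function() lm(mpg ~ wt, data = mtcars))

# Second call reads the file; the fitting function is never run
second <- fit_or_load(model_path, function() stop("should not refit"))

all.equal(coef(first), coef(second))
```

One caveat with this pattern: the cached file has to be deleted by hand whenever the training data or model settings change.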

Creating the Testing Dataset

We can now use the exact same process as we did with the training data set to prepare the testing data set.

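# Prepare Test Data
n_test <- length(testing_images$filepath)
testing_matrix <- matrix(
  nrow = n_test,  # size by the test set, not the training set
  ncol = 28 * 28
)

for (i in seq_len(n_test)) {
  testing_matrix[i, ] <-
    as.vector(
      png::readPNG(
        testing_images$filepath[i]
      )
    )
}

testing_data <- na.omit(
  cbind(
    dplyr::select(testing_images, label),
    as.data.frame(
      testing_matrix
    )
  )
)

I added na.omit() around the testing data because some of the values were missing and it was causing issues with the prediction function. A likely cause is allocating the matrix with the training row count (n) instead of n_test, which leaves every row past the last test image filled with NA. With nrow = n_test as above, na.omit() acts only as a safeguard.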

Make Predictions

Let’s make some predictions and see how we did! We can call the predict() function with the fitted {parsnip} model object and the testing data set. The metrics() function from the {yardstick} package makes it easy to calculate the performance metrics.

# Model Evaluation
predictions <- predict(
  rf_fit,
  testing_data
)

final_result <-
  dplyr::bind_cols(
    predictions,
    dplyr::select(
      testing_data,
      label
    )
  )

yardstick::metrics(
  final_result,
  truth = "label",
  estimate = ".pred_class"
)

#> # A tibble: 2 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.97
#> 2 kap multiclass 0.967
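The two rows are accuracy and Cohen's kappa (kap), which adjusts accuracy for the agreement expected by chance. A minimal base R sketch of both calculations follows; the eight labels are made up for illustration and are not results from the model.

```r
# Made-up true and predicted labels for a three-class problem
truth <- factor(c(0, 0, 1, 1, 2, 2, 2, 1), levels = 0:2)
pred  <- factor(c(0, 1, 1, 1, 2, 2, 0, 1), levels = 0:2)

# Confusion matrix: rows are true classes, columns are predictions
tab <- table(truth, pred)

# Observed agreement: share of cases on the diagonal
accuracy <- sum(diag(tab)) / sum(tab)
accuracy
#> [1] 0.75

# Agreement expected by chance, from the row and column totals
expected <- sum(rowSums(tab) * colSums(tab)) / sum(tab)^2

# Cohen's kappa: how far accuracy exceeds chance, rescaled to a maximum of 1
kappa <- (accuracy - expected) / (1 - expected)
kappa
```

Kappa is never higher than accuracy, and the two sit close together when the classes are roughly balanced, as they are in MNIST.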

97% accuracy for our model! Not bad at all. So how many did it get wrong?

wrong_idx <- which(final_result$label != final_result$.pred_class)
right_idx <- which(final_result$label == final_result$.pred_class)

length(wrong_idx)

#> [1] 300

300 of 10,000 testing images were labeled incorrectly. Let’s plot a random few of them to see why.

random_right <- sample(right_idx, 3)
random_wrong <- sample(wrong_idx, 3)

# Plot the mistakes
ggplot2::ggplot(
  data = data.frame(
    x = seq(1, 10, length.out = 6),
    y = 1,
    images = testing_images$filepath[c(random_right, random_wrong)]
  ),
  ggplot2::aes(
    x,
    y,
    image = images,
    label = paste(final_result$.pred_class[c(random_right, random_wrong)])
  )
) +
  ggimage::geom_image(
    size = .10
  ) +
  ggplot2::scale_y_continuous(
    limits = c(0, 2)
  ) +
  ggplot2::scale_x_continuous(
    limits = c(0, 11)
  ) +
  ggplot2::geom_text(
    size = 10,
    nudge_y = 0.25,
    color = c("green", "green", "green", "red", "red", "red")
  ) +
  ggplot2::theme_void()
Classification Results

We can see that the model was close: the digits it labeled incorrectly are difficult to distinguish.
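Beyond looking at a few images, a confusion table shows which digit pairs get mixed up most often. The sketch below uses ten made-up labels rather than the model's real output. With the article's objects, truth and pred would correspond to final_result$label and final_result$.pred_class.

```r
# Made-up true and predicted digits (not results from the model)
truth <- factor(c(4, 9, 9, 4, 7, 1, 9, 4, 3, 5), levels = 0:9)
pred  <- factor(c(9, 9, 4, 4, 1, 1, 4, 9, 3, 5), levels = 0:9)

# Cross-tabulate true digit against predicted digit
confusion <- table(truth = truth, predicted = pred)

# Turn the table into one row per (truth, predicted) pair
errors <- as.data.frame(confusion, stringsAsFactors = FALSE)

# Keep only the misclassified pairs that actually occur
errors <- errors[errors$truth != errors$predicted & errors$Freq > 0, ]

# Most frequent confusions first
errors <- errors[order(-errors$Freq), ]
errors
```

In this toy data the 4s and 9s are confused with each other most often, which is the kind of pattern the table makes easy to spot.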

Conclusion

I hope you enjoyed this article. Thank you for reading and I wish you the very best! Until the next one.

Code From Article: mnist_classification.R

Data analytics professional using R, SQL, Docker, TagUI, and more.