Build a multi-class classification neural network in R in fifty lines of code

The R language allows for rapid prototyping of machine learning and neural network models. Having learned to create neural networks in Python, I found prototyping them in R quick and easy. Although I still prefer Python for the flexibility it offers as a programming language and the finer control it gives over the algorithms, I can see myself using R for simple, quick projects.

In this tutorial, I am going to use the popular iris dataset to predict the species of flowers using a simple neural network. I will be using the neuralnet package to create a neural network and the tidyverse package for some handy tools. 

Let’s get started 

First, load the tidyverse and neuralnet packages.

library(tidyverse)
library(neuralnet)

Now, let us take a look at the iris dataset. It ships with R's built-in datasets package, so we can use it straight away by typing iris at the console.


When you print the dataset to the console, you will see that it has four feature columns (Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width) and one label column (Species), with 150 rows of data. The dataset covers only three species of flowers: setosa, versicolor, and virginica, with 50 rows each.
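If you would rather confirm these numbers than scan the printout, a few base R calls are enough. This is a small sketch using only base R functions (dim, names, table); no extra packages are needed.

```r
# iris ships with R's datasets package, so nothing needs to be installed
dim(iris)            # 150 5: 150 rows, 5 columns
names(iris)          # the four feature columns followed by Species
table(iris$Species)  # 50 rows for each of the three species
```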

The iris dataset

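Since the label column holds categorical values, it needs to be a factor. In the built-in iris dataset, Species is already stored as a factor, but converting it explicitly with the as_factor function (from the forcats package, which the tidyverse loads) makes the intent clear and leaves an existing factor unchanged.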

iris <- iris %>%
  mutate(Species = as_factor(Species))

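Here, %>% is the pipe operator (from the magrittr package, re-exported by the tidyverse). It passes the value on its left as the first argument of the function on its right, so iris %>% mutate(...) is equivalent to mutate(iris, ...). The mutate function then replaces the ‘Species’ column with its factor version.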
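To see that Species really is a factor already, you can inspect it with base R. This is a minimal sketch; species_chr is just an illustrative name, and the factor() call stands in for the conversion so that no extra packages are needed.

```r
# Species in the built-in iris data is already a factor with three levels
is.factor(iris$Species)  # TRUE
levels(iris$Species)     # "setosa" "versicolor" "virginica"

# Turning a character copy back into a factor gives an identical result,
# because factor() sorts the unique values alphabetically into levels
species_chr <- as.character(iris$Species)
identical(factor(species_chr), iris$Species)  # TRUE
```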

Data preprocessing

Now, let us visualize the dataset to see if we need to do any preprocessing. I am going to draw a boxplot to see if the dataset needs to be scaled and if there are any outliers. To that end, let me create a function to draw boxplots. 

draw_boxplot <- function(){
  iris %>%
    # reshape the four feature columns into name/value pairs
    pivot_longer(1:4, names_to="attributes") %>%
    ggplot(aes(attributes, value, fill=attributes)) +
    geom_boxplot()
}

The pivot_longer function reshapes the four feature columns into long format: we end up with one column, attributes, that holds the feature names and another, value, that holds their values. We then map attributes to the x-axis and value to the y-axis in ggplot, and geom_boxplot draws one boxplot per feature. Finally, call draw_boxplot() to draw the plot.
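To see what the long-format data handed to ggplot looks like, you can run the reshaping step on its own. This sketch assumes the tidyr package (part of the tidyverse) is installed; iris_long is an illustrative name.

```r
library(tidyr)

# Same reshaping as inside draw_boxplot(): columns 1 to 4 become name/value pairs
iris_long <- pivot_longer(iris, 1:4, names_to = "attributes")

dim(iris_long)    # 600 3: 150 rows x 4 features, and 3 columns
names(iris_long)  # "Species" "attributes" "value"
head(iris_long, 4)
```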

The boxplot drawn before preprocessing

We can observe that the columns have different scales and that the ‘Sepal.Width’ column has outliers. First, let us deal with the outliers using the squish function from the scales package. Note that this does not remove any rows. Instead, values below the lower limit are set to the lower limit and values above the upper limit are set to the upper limit (a technique often called winsorizing).

iris <- iris %>%
  # scales::squish is used because library(tidyverse) installs scales but does not attach it
  mutate(across(Sepal.Width, ~scales::squish(.x, quantile(.x, c(0.05, 0.95)))))

This squeezes the Sepal.Width values into the range between the 5th and the 95th percentiles.
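If you are curious what squish does under the hood, the same capping can be written in base R with pmax and pmin. The helper cap_at_quantiles below is a hypothetical name used only for illustration; for finite values it should match scales::squish with the same limits.

```r
# Hypothetical helper: cap (winsorize) a vector at two of its quantiles.
# For finite values this mirrors scales::squish(x, quantile(x, c(0.05, 0.95))).
cap_at_quantiles <- function(x, probs = c(0.05, 0.95)) {
  limits <- quantile(x, probs, names = FALSE)
  pmin(pmax(x, limits[1]), limits[2])  # raise low values, lower high values
}

capped <- cap_at_quantiles(iris$Sepal.Width)

range(iris$Sepal.Width)        # 2.0 4.4 before capping
range(capped)                  # now inside the 5th to 95th percentile range
length(capped) == nrow(iris)   # TRUE: no rows are dropped
```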

Now, let us also scale the feature columns using the scale function. By default, scale standardizes each column to Z-scores: it subtracts the column mean and divides by the column standard deviation.

iris <- iris %>%  
  mutate(across(1:4, scale))
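The Z-score of a value is z = (x - mean) / sd. This base R sketch checks that scale() computes exactly that for one column; x, z_manual, and z_scale are illustrative names. Because scale() returns a one-column matrix with extra attributes, the comparison flattens it with as.vector() first.

```r
# Z-score by hand: subtract the column mean, divide by the column standard deviation
x <- iris$Sepal.Length
z_manual <- (x - mean(x)) / sd(x)

# scale() returns a one-column matrix with attributes, so flatten it before comparing
z_scale <- as.vector(scale(x))

all.equal(z_manual, z_scale)               # TRUE
c(mean = mean(z_scale), sd = sd(z_scale))  # mean is 0 up to rounding error, sd is 1
```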

Once again, let us visualize the dataset to see if there are any improvements. 

The boxplot drawn after preprocessing

We can see that the columns have a similar scale and there are no outliers. Great!  

Splitting the dataset

Now that we are done with the preprocessing task, let us split the dataset into training data and test data. We will use 70% of the data as training data and the rest as the test data. 

While splitting the dataset, we need to make sure that rows are assigned to each set at random. So, let us first draw a random vector of row indices, without replacement, whose length equals 70% of the number of rows in the dataset.

training_data_rows <- floor(0.70 * nrow(iris))                        # number of training rows
training_indices <- sample(seq_len(nrow(iris)), training_data_rows)  # random row indices, no repeats

floor(0.70 * nrow(iris)) gives the number of training rows. Since we have 150 rows, this value is 105, so our training dataset will consist of 105 rows. sample() then picks 105 distinct random integers between 1 and 150, which serve as the row indices of the training data.
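Because sample() is random, every run produces a different split. If you want the same split each time, set a seed first. The seed value 123 below is arbitrary, and this base R sketch is an optional addition rather than part of the original code.

```r
# Fix the random seed so the same rows are drawn on every run (123 is arbitrary)
set.seed(123)

training_data_rows <- floor(0.70 * nrow(iris))
training_data_rows  # 105

training_indices <- sample(seq_len(nrow(iris)), training_data_rows)

length(training_indices)         # 105
anyDuplicated(training_indices)  # 0, because sample() draws without replacement by default

# Re-seeding reproduces exactly the same indices
set.seed(123)
identical(training_indices, sample(seq_len(nrow(iris)), training_data_rows))  # TRUE
```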

Now, let us split our dataset using training_indices. 

training_data <- iris[training_indices,] 
test_data <- iris[-training_indices,]
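Negative indices in R drop the listed rows, which is why iris[-training_indices, ] is exactly the complement of the training set. This base R sketch repeats the split with an arbitrary seed and checks that the two sets do not overlap and together cover every row.

```r
set.seed(123)  # arbitrary seed, only so the split is reproducible
training_indices <- sample(seq_len(nrow(iris)), floor(0.70 * nrow(iris)))

training_data <- iris[training_indices, ]
test_data     <- iris[-training_indices, ]  # negative indices drop the listed rows

nrow(training_data)  # 105
nrow(test_data)      # 45

# No row lands in both sets, and together they cover the whole dataset
length(intersect(rownames(training_data), rownames(test_data)))  # 0
nrow(training_data) + nrow(test_data) == nrow(iris)              # TRUE
```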


What are you waiting for now? Let us train a neural network on our training data. 

To create a neural network, I am going to use the neuralnet package. I will keep the default settings, except for the architecture: two hidden layers with two neurons each. By default, neuralnet uses the logistic function as the activation function. To see what the other default values are, check out its documentation.
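For reference, the logistic function squashes any real input into the range (0, 1) via f(x) = 1 / (1 + e^(-x)). The sketch below is a plain R version for illustration only; logistic is a hypothetical name, and neuralnet uses its own internal implementation.

```r
# Logistic (sigmoid) activation: f(x) = 1 / (1 + exp(-x))
logistic <- function(x) 1 / (1 + exp(-x))

logistic(0)         # 0.5
logistic(c(-5, 5))  # about 0.0067 and 0.9933
```

With linear.output = FALSE, neuralnet applies this function in the hidden layers and in the output layer.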

nn <- neuralnet(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
                data=training_data, hidden=c(2,2), linear.output = FALSE)

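Here, the first argument is a formula: the response, Species, goes on the left of the ~ and the predictors, the four feature columns, go on the right. Setting linear.output = FALSE applies the activation function to the output neurons as well, which suits classification. Execute the code to start training the model.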
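Writing out every predictor gets tedious when there are many features. One option is to build the formula with base R's reformulate(); feature_columns and model_formula are illustrative names, and the resulting formula object should be usable as the first argument to neuralnet() in place of the written-out one.

```r
# Species is the response; the four measurement columns are the predictors
feature_columns <- names(iris)[1:4]
model_formula <- reformulate(feature_columns, response = "Species")

model_formula            # Species on the left of ~, the four features on the right
all.vars(model_formula)  # response first, then the predictors
```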

Once done, I am going to call plot(nn) to draw the architecture of the network.  

My neural network architecture

The weight values you get will differ from mine, since the weights are initialized randomly (call set.seed() before training if you want reproducible results). Now, let us test our neural network model.


I am going to create a predict function to get the performance of our model on both the test data and the training data.  

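predict <- function(data){
  # Forward pass on the four feature columns; net.result has one column per output neuron
  prediction <- data.frame(neuralnet::compute(nn, data[, 1:4])$net.result)

  # Output neurons follow the order of the Species factor levels
  labels <- c("setosa", "versicolor", "virginica")
  prediction_label <- data.frame(max.col(prediction)) %>%
    mutate(prediction=labels[max.col.prediction.]) %>%
    select(2) %>%
    unlist()

  # Confusion matrix: actual species (rows) vs. predicted species (columns)
  table(data$Species, prediction_label)
}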

The function takes a dataset as its argument and returns a confusion matrix. To predict the species, we use the compute function from the neuralnet package, which returns the value of each output neuron for every row. Because the output layer uses the logistic activation, each value lies between 0 and 1 and can be read as a score for that species. The max.col function then picks, for each row, the column with the highest value, and the predicted species is the one that corresponds to that column.
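The max.col step is easiest to see on a small matrix. The output values below are made up for illustration and are not taken from the trained model. Note that max.col() breaks ties at random by default; passing ties.method = "first" makes the choice deterministic.

```r
# Made-up output-neuron values for three flowers
# (one row per flower, one column per species, in factor-level order)
outputs <- matrix(c(0.95, 0.03, 0.01,
                    0.02, 0.80, 0.30,
                    0.01, 0.15, 0.90),
                  nrow = 3, byrow = TRUE)

labels <- c("setosa", "versicolor", "virginica")

# max.col() returns, for each row, the column index of the largest value
winners <- max.col(outputs, ties.method = "first")
winners          # 1 2 3
labels[winners]  # "setosa" "versicolor" "virginica"
```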

Next, let us check the performance on our training data by calling predict(training_data).

Actual \ Predicted  setosa  versicolor  virginica
setosa                  36           0          0
versicolor               0          32          0
virginica                0           0         37

As you can see, the accuracy is 100%. But hold on! This is just the training data. The real test is how the model does on rows it has never seen, so let us call predict(test_data). We call it the test data for a reason!

Actual \ Predicted  setosa  versicolor  virginica
setosa                  14           0          0
versicolor               0          17          1
virginica                0           0         13

Not bad at all. The model got only one of the 45 test rows wrong (a versicolor flower predicted as virginica), which gives us an accuracy of 44/45, or about 97.8%. Now, you can play around with hyperparameters such as the number of hidden layers and the number of neurons in each layer to see if you can get better accuracy.
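Accuracy is the number of correctly classified rows (the diagonal of the confusion matrix) divided by the total number of rows. This base R sketch recomputes it from the test-set counts in the table above; species and test_confusion are illustrative names.

```r
species <- c("setosa", "versicolor", "virginica")

# Test-set confusion matrix from the tutorial (rows = actual, columns = predicted)
test_confusion <- matrix(c(14,  0,  0,
                            0, 17,  1,
                            0,  0, 13),
                         nrow = 3, byrow = TRUE,
                         dimnames = list(Actual = species, Predicted = species))

# Accuracy = correct predictions (the diagonal) / all predictions
accuracy <- sum(diag(test_confusion)) / sum(test_confusion)
round(100 * accuracy, 1)  # 97.8
```

The same two lines should work on the table returned by predict(test_data), assuming every species appears among the predictions so that the table is square with rows and columns in the same order.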

You can access the full code used in this tutorial from my GitHub repo. 
