Python – How to Draw Confusion Matrix using Matplotlib

In this post, you will learn how to draw (show) a confusion matrix using the Matplotlib Python package. This technique is very handy for assessing the performance of classification models trained with different classification algorithms.

Confusion Matrix using Matplotlib

To demonstrate plotting a confusion matrix with Matplotlib, let's fit a pipeline estimator to the scikit-learn breast cancer dataset, using StandardScaler to standardize the features and RandomForestClassifier as the machine learning algorithm.

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
#
# Load the breast cancer data set
#
bc = datasets.load_breast_cancer()
X = bc.data
y = bc.target
#
# Create training and test split
#
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=1, stratify=y)
#
# Create the pipeline
#
pipeline = make_pipeline(StandardScaler(),
                         RandomForestClassifier(n_estimators=10, max_features=5, max_depth=2, random_state=1))
#
# Fit the Pipeline estimator
#
pipeline.fit(X_train, y_train)
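Because the pipeline chains the two steps, calling predict on it standardizes the features with the scaler fitted on the training data and then passes the result to the random forest. The sketch below checks this by pulling the fitted steps out of the pipeline. It re-creates the split and pipeline from above so it runs on its own, and it relies on make_pipeline naming each step after its class name in lowercase; the variable names scaler, forest and manual_pred are illustrative only.

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Same data, split and pipeline as in the code above
X, y = datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=1, stratify=y)
pipeline = make_pipeline(
    StandardScaler(),
    RandomForestClassifier(n_estimators=10, max_features=5, max_depth=2, random_state=1))
pipeline.fit(X_train, y_train)

# make_pipeline names each step after its class name in lowercase
scaler = pipeline.named_steps["standardscaler"]
forest = pipeline.named_steps["randomforestclassifier"]

# Predicting through the pipeline = transform with the fitted scaler, then classify
manual_pred = forest.predict(scaler.transform(X_test))
print(np.array_equal(manual_pred, pipeline.predict(X_test)))  # True
```

This is why the test data is passed to pipeline.predict unscaled in the next step: the pipeline applies the training-set scaling itself.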

Once the estimator is fit to the training dataset, the next step is to plot the confusion matrix. This involves the following steps:

  • Get the predictions by invoking the predict method on the fitted estimator (the pipeline).
  • Create the confusion matrix from the actual and predicted labels of the test dataset, using the confusion_matrix function from sklearn.metrics. It returns a 2D array whose rows correspond to the actual classes and whose columns correspond to the predicted classes.
  • Display the array as a color-shaded matrix using Matplotlib's matshow method. In this example, the Blues colormap is used.
  • Because matshow only draws the colored cells, loop through the array and use ax.text to write each count inside its cell.

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
#
# Get the predictions
#
y_pred = pipeline.predict(X_test)
#
# Calculate the confusion matrix
#
conf_matrix = confusion_matrix(y_true=y_test, y_pred=y_pred)
#
# Print the confusion matrix using Matplotlib
#
fig, ax = plt.subplots(figsize=(7.5, 7.5))
ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3)
for i in range(conf_matrix.shape[0]):
    for j in range(conf_matrix.shape[1]):
        ax.text(x=j, y=i, s=conf_matrix[i, j], va='center', ha='center', size='xx-large')

plt.xlabel('Predictions', fontsize=18)
plt.ylabel('Actuals', fontsize=18)
plt.title('Confusion Matrix', fontsize=18)
plt.show()

This is what the confusion matrix looks like:

Fig 1. Confusion Matrix representing predictions on breast cancer test dataset
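The matshow and ax.text steps used for Fig 1 can be wrapped in a reusable function. The sketch below is one way to do that, under a few assumptions: plot_conf_matrix is a hypothetical helper name, the class names are supplied by the caller in label order (for the breast cancer data, bc.target_names gives ['malignant', 'benign'] for labels 0 and 1), and the non-interactive Agg backend is selected so the sketch runs without a display. Compared with the code above, it also replaces the default 0 and 1 tick labels with class names.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove this line to show plots interactively
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def plot_conf_matrix(cm, class_names, cmap=plt.cm.Blues):
    """Hypothetical helper: draw a confusion matrix with matshow and annotate each cell."""
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.matshow(cm, cmap=cmap, alpha=0.3)
    # Row i is the actual class, column j the predicted class; text goes at (x=j, y=i)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(x=j, y=i, s=cm[i, j], va="center", ha="center", size="xx-large")
    # Replace the default 0, 1, ... tick labels with the class names
    ax.set_xticks(range(len(class_names)))
    ax.set_xticklabels(class_names)
    ax.set_yticks(range(len(class_names)))
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predictions")
    ax.set_ylabel("Actuals")
    return fig, ax


# Usage with hypothetical toy labels; call plt.show() or fig.savefig(...) to view the result
cm = confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1])
fig, ax = plot_conf_matrix(cm, class_names=["class 0", "class 1"])
```

For the pipeline above, the call would be plot_conf_matrix(conf_matrix, bc.target_names).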

Confusion Matrix using Mlxtend Package

Another option is the mlxtend package (by Dr. Sebastian Raschka), whose mlxtend.plotting module can also be used to draw a confusion matrix. It is simpler and easier to use than the Matplotlib approach in the previous section: import the plot_confusion_matrix function and pass the confusion matrix array to its conf_mat parameter. Here, the Greens colormap is used to display the confusion matrix.

import matplotlib.pyplot as plt
from mlxtend.plotting import plot_confusion_matrix

fig, ax = plot_confusion_matrix(conf_mat=conf_matrix, figsize=(6, 6), cmap=plt.cm.Greens)
plt.xlabel('Predictions', fontsize=18)
plt.ylabel('Actuals', fontsize=18)
plt.title('Confusion Matrix', fontsize=18)
plt.show()

Here is what the confusion matrix looks like:

Fig 2. Confusion Matrix drawn using Mlxtend plot_confusion_matrix method
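Whichever library draws the plot, reading it correctly depends on the array layout. The sketch below uses hypothetical toy labels (not the breast cancer predictions) to check three things: the rows of the confusion_matrix output are actual classes and the columns are predicted classes; the diagonal holds the correct predictions, so its sum divided by the total is the accuracy; and dividing each row by its total gives per-class rates, which are easier to compare when the classes are imbalanced. The normalize parameter it uses assumes scikit-learn 0.22 or later.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Hypothetical toy labels with imbalanced classes: two actual 0s, four actual 1s
y_true = [0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 1, 1, 1, 0]

cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[1 1]
#  [1 3]]
# Row i = actual class i, column j = predicted class j

# Correct predictions lie on the diagonal, so accuracy = trace / total
accuracy = np.trace(cm) / cm.sum()
print(np.isclose(accuracy, accuracy_score(y_true, y_pred)))  # True

# Row-normalize: divide each row by the number of actual samples in that class
cm_normed = cm / cm.sum(axis=1, keepdims=True)
# row 0 -> [0.5, 0.5], row 1 -> [0.25, 0.75]

# scikit-learn can return the same row-normalized matrix via normalize="true"
print(np.allclose(cm_normed, confusion_matrix(y_true, y_pred, normalize="true")))  # True
```

A row-normalized array can be passed to matshow just like the raw counts. mlxtend's plot_confusion_matrix also offers a show_normed option for annotating cells with normalized values; check the documentation of your installed version for details.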
Ajitesh Kumar
I have recently been working in the area of Data Science and Machine Learning / Deep Learning. In addition, I am passionate about various technologies, including programming languages such as Java/JEE, JavaScript, Python, R, and Julia, and technologies such as blockchain, mobile computing, cloud-native technologies, application security, cloud computing platforms, and big data. I would love to connect with you on LinkedIn.
Posted in Data Science, Machine Learning, Python.
