Python – Scatter Plot Different Classes


In this post, you will learn how to create scatter plots in Python that show two or more classes, which is useful when you are working on a machine learning classification problem.

As you work on a classification problem, you want to know whether the classes are linearly separable, in other words, whether the problem is linear or non-linear. This, in turn, helps you decide which kind of machine learning classification algorithm to use.

Specifically, you will use a scatter plot to judge whether two or more classes are linearly separable. You may also want to read up on the what, when and how of the scatter plot matrix, which can be used for the same purpose by showing the pairwise (bivariate) relationships between the predictor variables.
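As a rough illustration of the scatter plot matrix idea, the sketch below colors each point by its class using pandas' scatter_matrix. The data values, the OTHER column and the color mapping are made up for this example and are not taken from the data set used in this post.

```python
import pandas as pd
from pandas.plotting import scatter_matrix

# Made-up values for illustration only; 'OTHER' is a hypothetical third feature
demo = pd.DataFrame({
    'STG': [0.05, 0.10, 0.08, 0.30, 0.35, 0.40],
    'SCG': [0.10, 0.05, 0.12, 0.40, 0.30, 0.45],
    'OTHER': [0.05, 0.10, 0.08, 0.25, 0.30, 0.28],
    'UNS': ['very_low', 'very_low', 'very_low', 'Low', 'Low', 'Low'],
})

# One color per class (an assumed mapping, pick any colors you like)
class_colors = {'very_low': 'red', 'Low': 'blue'}
point_colors = demo['UNS'].map(class_colors).tolist()

# Plot every pair of features; extra keyword arguments such as c are
# forwarded to the underlying scatter calls
features = ['STG', 'SCG', 'OTHER']
axes = scatter_matrix(demo[features], c=point_colors,
                      diagonal='hist', figsize=(6, 6))
```

If the two colors can be split by a straight line in some panel, that pair of features is a good candidate for a linear classifier.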

Python Code for Scatter Plot

Here is the code to load the data and preview the first few rows. The class labels are stored in the UNS column.

import pandas as pd

df2 = pd.read_csv('/Users/apple/Downloads/user knowledge level - Sheet1.csv')
df2.head()
Fig 1. User Knowledge Level Data set

The code below draws a scatter plot of two classes, Very Low and Low, with the features STG and SCG on the X and Y axes.

import matplotlib.pyplot as plt

# Match both spellings of the Very Low label ('very_low' and 'Very Low')
very_low = df2['UNS'].isin(['very_low', 'Very Low'])
low = df2['UNS'] == 'Low'

plt.scatter(df2.loc[very_low, 'STG'],
            df2.loc[very_low, 'SCG'],
            marker='D',
            color='red',
            label='Very Low')
plt.scatter(df2.loc[low, 'STG'],
            df2.loc[low, 'SCG'],
            marker='o',
            color='blue',
            label='Low')
plt.xlabel('STG')
plt.ylabel('SCG')
plt.legend()
plt.show()
Fig 2. Scatter plot representing Very Low and Low classes
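The same idea extends to any number of classes by looping over the groups instead of writing one plt.scatter call per class. The following is a minimal sketch under assumptions: scatter_by_class is a hypothetical helper name, and the small data frame holds made-up values that only mimic the column names used in this post.

```python
import matplotlib.pyplot as plt
import pandas as pd

def scatter_by_class(df, x, y, label_col):
    """Draw one scatter series per distinct value in label_col."""
    fig, ax = plt.subplots()
    # groupby yields (label, sub-frame) pairs in sorted label order
    for label, group in df.groupby(label_col):
        ax.scatter(group[x], group[y], label=str(label))
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.legend(title=label_col)
    return fig, ax

# Made-up values for illustration only
demo = pd.DataFrame({
    'STG': [0.05, 0.10, 0.30, 0.35, 0.60, 0.65],
    'SCG': [0.10, 0.05, 0.40, 0.30, 0.70, 0.60],
    'UNS': ['very_low', 'very_low', 'Low', 'Low', 'Middle', 'Middle'],
})

fig, ax = scatter_by_class(demo, 'STG', 'SCG', 'UNS')
```

Call plt.show() afterwards to display the figure. Matplotlib picks a different default color for each series, so no explicit color list is needed.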

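A similar plot can be produced with a single function call using category_scatter from the mlxtend Python package, authored by Dr. Sebastian Raschka. Note that category_scatter plots every class found in the label column, so the np.where line first merges the 'Very Low' spelling into 'very_low'. Here is the code: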

import numpy as np
from mlxtend.plotting import category_scatter

# Merge the 'Very Low' spelling into 'very_low' so both plot as one class
df2['UNS'] = np.where(df2['UNS'] == 'Very Low', 'very_low', df2['UNS'])
fig = category_scatter(x='STG', y='SCG', label_col='UNS',
                       data=df2, legend_loc='upper right')
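A visual impression of linear separability can be cross-checked numerically. The sketch below is not part of the original workflow: it runs a plain perceptron, which is known to converge only when two classes are linearly separable, and reports whether it completed a full pass without mistakes. The function name looks_linearly_separable, the epoch budget and the toy points are assumptions made for illustration.

```python
import numpy as np

def looks_linearly_separable(X, y, max_epochs=1000):
    """Return True if a perceptron finds a separating line within max_epochs.

    X: (n_samples, n_features) values; y: labels with exactly two classes.
    False only means no separator was found within the epoch budget.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    classes = np.unique(y)
    if len(classes) != 2:
        raise ValueError("expected exactly two classes")
    # Map the two labels to -1 / +1 and append a constant bias column
    signs = np.where(y == classes[0], -1.0, 1.0)
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    w = np.zeros(Xb.shape[1])
    for _ in range(max_epochs):
        mistakes = 0
        for xi, si in zip(Xb, signs):
            if si * np.dot(w, xi) <= 0:  # misclassified or on the boundary
                w += si * xi             # standard perceptron update
                mistakes += 1
        if mistakes == 0:
            return True
    return False

# Two well-separated toy clusters
print(looks_linearly_separable([[0, 0], [0, 1], [3, 3], [3, 4]],
                               ['a', 'a', 'b', 'b']))   # True
# XOR layout: no straight line separates the classes
print(looks_linearly_separable([[0, 0], [1, 1], [0, 1], [1, 0]],
                               [0, 0, 1, 1]))           # False
```

On real data you might pass df2[['STG', 'SCG']] restricted to two classes. Treat a False result as a hint rather than proof, since a separable data set with a very small margin may need more epochs.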
Ajitesh Kumar