In this post, you will learn how to create scatter plots in Python that show two or more classes, which is useful when you are working on a machine learning classification problem.

As you work on a classification problem, you want to understand whether the classes are linearly separable, that is, whether a straight line (or, in higher dimensions, a hyperplane) can separate them. In other words, is the classification problem linear or non-linear? This, in turn, helps you decide which kind of machine learning classification algorithm to use.

Specifically, you will learn how to use a scatter plot to identify whether two or more classes are linearly separable.
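A scatter plot gives a visual answer to the separability question; the same question can also be checked numerically on small data. The sketch below is not part of the original post: it uses a plain perceptron, which is guaranteed to stop making mistakes only when the two classes are linearly separable. The function name `is_linearly_separable`, the epoch budget, and the toy points are all assumptions made for illustration.

```python
def is_linearly_separable(points, labels, max_epochs=1000):
    """Return True if a perceptron finds a separating line within max_epochs.

    points: list of (x, y) tuples; labels: list of +1 / -1 values.
    A False result only means no separating line was found within the budget.
    """
    w = [0.0, 0.0]  # weights for the two features
    b = 0.0         # bias term
    for _ in range(max_epochs):
        mistakes = 0
        for (x, y), t in zip(points, labels):
            # A point is misclassified when it lies on the wrong side of
            # (or exactly on) the current line.
            if t * (w[0] * x + w[1] * y + b) <= 0:
                w[0] += t * x
                w[1] += t * y
                b += t
                mistakes += 1
        if mistakes == 0:
            return True  # a full pass with no mistakes: the line separates the classes
    return False


# Two well-separated clusters (hypothetical values).
separable = is_linearly_separable(
    [(0.1, 0.2), (0.2, 0.1), (0.8, 0.9), (0.9, 0.7)], [-1, -1, 1, 1]
)
# An XOR-style layout, which no straight line can split.
xor_like = is_linearly_separable(
    [(0, 0), (1, 1), (0, 1), (1, 0)], [-1, -1, 1, 1]
)
print(separable, xor_like)
```

In practice the plot comes first; a check like this is mainly useful to confirm what the plot suggests.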

## Python Code for Scatter Plot

Here is the code to load the data. The class labels are stored in the `UNS` column:

import pandas as pd
df2 = pd.read_csv('/Users/apple/Downloads/user knowledge level - Sheet1.csv')
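To find out the class labels, one option is to list the distinct values of the label column and count the rows in each class. The sketch below assumes the label column is `UNS`, as in the rest of this post, and runs on a small made-up DataFrame (`toy`) rather than the real file, so the numbers are purely illustrative.

```python
import pandas as pd

# Toy stand-in for the post's dataset; the values are made up for illustration.
toy = pd.DataFrame({
    "STG": [0.10, 0.08, 0.30, 0.25, 0.12],
    "SCG": [0.10, 0.05, 0.40, 0.35, 0.15],
    "UNS": ["very_low", "Very Low", "Low", "Low", "very_low"],
})

# List the distinct labels to spot spelling variants of the same class.
raw_labels = sorted(toy["UNS"].unique())

# Merge the two spellings of the same class, then count rows per class.
toy["UNS"] = toy["UNS"].replace({"Very Low": "very_low"})
class_counts = toy["UNS"].value_counts().to_dict()

print(raw_labels)
print(class_counts)
```

Running the same two steps on `df2` would show which spellings occur in the actual data before any plotting is done.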


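The code below draws a scatter plot of two classes, `very_low` and `Low`, using the features `STG` and `SCG` as the X and Y axes. The first class appears in the data under two spellings, `very_low` and `Very Low`, so the filter matches both.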

import matplotlib.pyplot as plt

very_low = df2['UNS'].isin(['very_low', 'Very Low'])
low = df2['UNS'] == 'Low'
plt.scatter(df2.loc[very_low, 'STG'], df2.loc[very_low, 'SCG'],
            marker='D', color='red', label='Very Low')
plt.scatter(df2.loc[low, 'STG'], df2.loc[low, 'SCG'],
            marker='o', color='blue', label='Low')
plt.xlabel('STG')
plt.ylabel('SCG')
plt.legend()
plt.show()
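When there are more than two classes, repeating the `plt.scatter` call for each one gets tedious. A possible generalization, sketched here under assumptions, is to loop over the groups of the label column. The helper name `scatter_by_class`, the marker sequence, and the toy data are hypothetical and not from the original post.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import pandas as pd


def scatter_by_class(df, x, y, label_col, markers="Dos^v"):
    """Draw one scatter series per distinct value in label_col and return the Axes."""
    fig, ax = plt.subplots()
    for i, (name, group) in enumerate(df.groupby(label_col)):
        # Cycle through the marker characters; colors come from matplotlib's default cycle.
        ax.scatter(group[x], group[y],
                   marker=markers[i % len(markers)], label=str(name))
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.legend()
    return ax


# Toy data with made-up values, using the column names from this post.
toy = pd.DataFrame({
    "STG": [0.10, 0.08, 0.30, 0.25],
    "SCG": [0.10, 0.05, 0.40, 0.35],
    "UNS": ["very_low", "very_low", "Low", "Low"],
})
ax = scatter_by_class(toy, "STG", "SCG", "UNS")
```

In a notebook or interactive session, the `matplotlib.use("Agg")` line can be dropped and `plt.show()` called after the function returns.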


A similar scatter plot can be produced with a single call to the `category_scatter` function from the mlxtend Python package, authored by Dr. Sebastian Raschka. The function draws one series for each class found in the label column. Here is the code:

import numpy as np
from mlxtend.plotting import category_scatter
df2['UNS'] = np.where(df2['UNS'] == 'Very Low', 'very_low', df2['UNS'])
fig = category_scatter(x='STG', y='SCG', label_col='UNS',
                       data=df2, legend_loc='upper right')