Python Keras for multi-class classification using the IRIS dataset
In this post, you will learn how to train a neural network for multi-class classification using the Python Keras library and the IRIS dataset from Sklearn. If you are a deep learning enthusiast, knowing how to use Keras to train a multi-class classification neural network is a useful skill.
This post covers the steps for training such a network with Keras and the complete Python code that applies them to the IRIS dataset.
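Training a neural network for multi-class classification using Keras takes seven steps, which correspond to the commented sections of the code below: import the Keras modules, create the network, compile the network, load the IRIS dataset, create the training and test split, convert the labels to categorical (one-hot) form, and fit the network.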
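Here is the Python Keras code for training a neural network for multi-class classification of the IRIS dataset. Pay attention to the following important aspects of the code: the output layer has three units with softmax activation, one per IRIS class; the network is compiled with the categorical_crossentropy loss; the integer labels are converted to one-hot form with to_categorical; and the training and test split is stratified, so each class appears in the same proportion in both sets.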
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
#
# Import Keras modules
#
from keras import models
from keras import layers
from keras.utils import to_categorical
#
# Create the network
#
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(4,)))
network.add(layers.Dense(3, activation='softmax'))
#
# Compile the network
#
network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
#
# Load the iris dataset
#
iris = datasets.load_iris()
X = iris.data
y = iris.target
#
# Create training and test split
#
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
#
# Create categorical labels
#
train_labels = to_categorical(y_train)
test_labels = to_categorical(y_test)
#
# Fit the neural network
#
network.fit(X_train, train_labels, epochs=20, batch_size=40)
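To make the label and loss handling in the code above more concrete, here is a minimal NumPy sketch of what to_categorical produces for integer class labels and of the categorical cross-entropy that the categorical_crossentropy loss computes from the softmax output. The helper names one_hot, softmax and categorical_crossentropy and the example scores are illustrative assumptions, not part of Keras; Keras's own implementation may differ in details such as clipping and batch reduction.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Convert integer class labels (e.g. 0, 1, 2) into one-hot rows,
    similar in spirit to keras.utils.to_categorical (illustrative helper)."""
    encoded = np.zeros((len(labels), num_classes), dtype="float32")
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded


def softmax(logits):
    """Row-wise softmax: turns the 3 raw outputs per sample into probabilities."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_prob)), with probabilities clipped."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_prob), axis=1)))


# Hypothetical example: three IRIS samples belonging to classes 0, 2 and 1
labels = np.array([0, 2, 1])
y_true = one_hot(labels, num_classes=3)
print(y_true)

# Hypothetical raw scores from the final Dense(3) layer, before softmax
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 3.0],
                   [0.1, 1.5, 0.4]])
y_prob = softmax(logits)
print(y_prob.sum(axis=1))  # each row sums to (approximately) 1
print(categorical_crossentropy(y_true, y_prob))
```

The loss only looks at the probability assigned to the true class in each row, which is why the labels must be one-hot encoded to match the three softmax outputs.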
Once the network is fit, you can measure its accuracy on the test data with the following code. Note the use of the evaluate function.
#
# Get the accuracy of test data set
#
test_loss, test_acc = network.evaluate(X_test, test_labels)
#
# Print the test accuracy
#
print('Test Accuracy: ', test_acc, '\nTest Loss: ', test_loss)
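With one-hot labels and the categorical_crossentropy loss, the accuracy reported by evaluate is categorical accuracy: the fraction of samples whose highest-probability class matches the true class. The NumPy sketch below reproduces that calculation from a probability matrix such as the one network.predict(X_test) returns; the helper name categorical_accuracy_np and the sample arrays are hypothetical.

```python
import numpy as np


def categorical_accuracy_np(y_true_onehot, y_prob):
    """Fraction of rows where argmax of the prediction equals argmax of the label."""
    predicted_class = np.argmax(y_prob, axis=1)
    true_class = np.argmax(y_true_onehot, axis=1)
    return float(np.mean(predicted_class == true_class))


# Hypothetical softmax outputs for four test samples (one row per sample)
y_prob = np.array([[0.90, 0.05, 0.05],
                   [0.10, 0.70, 0.20],
                   [0.30, 0.40, 0.30],
                   [0.20, 0.20, 0.60]])
# Hypothetical one-hot test labels for classes 0, 1, 2, 2
y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [0, 0, 1]])

print(categorical_accuracy_np(y_true, y_prob))  # 3 of 4 correct, prints 0.75
```

With the trained network, np.argmax(network.predict(X_test), axis=1) gives the predicted class index for each test sample, which can be compared directly against y_test.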
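Here is a summary of what you learned about using Keras to train a multi-class classification neural network: the network ends in a Dense layer with one unit per class and softmax activation, it is compiled with the categorical_crossentropy loss, it is trained on labels converted to one-hot form with to_categorical, and its accuracy is measured on held-out test data with evaluate.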