This tutorial explores how to generate ROC curves for multiclass datasets in Python. A ROC curve is defined for binary classification, so we will extend the same concept to multiclass datasets. If you don’t know how to make a ROC curve for a binary problem, check out my previous tutorial first, as we will begin where that one finished.
Extension of Concept
As you know, a ROC curve needs binary labels. So, if we have a multiclass dataset, our first goal is to convert the multiclass labels into binary ones. This process is called binarization. We will use the most common approach, one-vs-all (OVA), also known as one-vs-rest, to binarize the labels and extend the concept of ROC to multiclass. I am opting for this approach as it is simple to understand and apply.
In this approach, we take one class as the positive class and treat all the remaining classes together as the negative class. This gives a binary problem, for which we can calculate the AUC. We repeat the process for each class of the multiclass dataset. Finally, we take the average of all the AUC scores.
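The procedure described above can be sketched by hand on a tiny example. The labels and scores below are made-up values for illustration only (they are not from the iris model), and the variable names are hypothetical; the sketch uses scikit-learn's `roc_auc_score` to get each binary AUC.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical toy data: 6 samples, 3 classes (labels 0, 1, 2).
y_true = np.array([0, 0, 1, 1, 2, 2])
# Made-up scores, one column per class.
scores = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.3, 0.3, 0.4],
    [0.1, 0.2, 0.7],
    [0.2, 0.4, 0.4],
])

per_class_auc = []
for k in range(3):
    # One-vs-all: class k is positive (1), every other class is negative (0).
    y_binary = (y_true == k).astype(int)
    per_class_auc.append(roc_auc_score(y_binary, scores[:, k]))

# Unweighted mean of the per-class AUC scores.
macro_auc = float(np.mean(per_class_auc))
print(per_class_auc, macro_auc)
```

Each pass through the loop is an ordinary binary ROC problem; only the definition of "positive" changes from one class to the next.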
Step 1: Importing Libraries
Let’s import all the necessary libraries. I will be working on the iris dataset as it has 3 classes.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import roc_curve, auc
Step 2: Loading the Dataset
Load the iris dataset from sklearn datasets. It consists of 3 classes with labels encoded as 0, 1, and 2.
iris = datasets.load_iris()
X = iris.data
y = iris.target
Step 3: Binarize
We have 3 classes in our dataset, so binarizing produces one column per class, where each column marks that class as positive (1) and the rest as negative (0). As a result, y goes from a shape of 150×1 to 150×3.
print(y.shape)
y_bin = label_binarize(y, classes=[0, 1, 2])
print(y_bin.shape)
Output:
(150,)
(150, 3)
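To make the new shape more concrete, here is a small sketch of what `label_binarize` does to a handful of labels. The list `y_small` is a hypothetical example, not part of the iris data.

```python
from sklearn.preprocessing import label_binarize

# Hypothetical labels for four samples.
y_small = [0, 1, 2, 1]

# One column per class; a 1 marks the class the sample belongs to.
y_small_bin = label_binarize(y_small, classes=[0, 1, 2])
print(y_small_bin)
# [[1 0 0]
#  [0 1 0]
#  [0 0 1]
#  [0 1 0]]
```

Column i of the result is exactly the "class i versus the rest" target used in the OVA approach.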
Step 4: OVA Method
scikit-learn provides a ready-made class for the OVA method named OneVsRestClassifier. It takes a binary classifier as a parameter and turns it into a multiclass classifier using the OVA approach explained earlier.
model = OneVsRestClassifier(LogisticRegression())
train_X, test_X, train_y, test_y = train_test_split(X, y_bin, test_size=0.2, random_state=2023)
model.fit(train_X, train_y)
y_prob = model.predict_proba(test_X)
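As a rough check of what the wrapper does internally, the sketch below fits it on the full iris data and inspects the fitted `estimators_` attribute, which holds one binary classifier per class. The variable name `ovr` and the `max_iter=1000` setting are my own assumptions, added only to avoid possible convergence warnings; they are not part of the tutorial's model.

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

X, y = datasets.load_iris(return_X_y=True)
y_bin = label_binarize(y, classes=[0, 1, 2])

# max_iter=1000 is an assumption made for this sketch only.
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000))
ovr.fit(X, y_bin)

# One fitted binary classifier per class.
print(len(ovr.estimators_))

# One score column per class for each sample.
probs = ovr.predict_proba(X[:5])
print(probs.shape)
```

Because the model is trained on the binarized label matrix, each column of `predict_proba` is the score of one "class versus rest" classifier, and the rows are not guaranteed to sum to 1.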
Step 5: Calculating AUC
Our iris dataset has 3 classes, so we need to store 3 sets of fpr values, 3 sets of tpr values, and 3 auc values. We will store them in lists. We need the fpr and tpr values to draw the ROC curve, and the roc_curve function calculates them. In addition to these values, the function returns the thresholds, which are not needed for plotting the curve, so we assign them to the throwaway variable _.
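To see what roc_curve returns before applying it to the iris model, here is a minimal sketch on four made-up samples. The names `fpr_demo`, `tpr_demo`, and `thresholds_demo` are hypothetical and the data is for illustration only.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# Hypothetical binary labels and scores for four samples.
y_true_demo = np.array([0, 0, 1, 1])
y_score_demo = np.array([0.1, 0.4, 0.35, 0.8])

# roc_curve returns three arrays of equal length:
# false positive rates, true positive rates, and the thresholds that produced them.
fpr_demo, tpr_demo, thresholds_demo = roc_curve(y_true_demo, y_score_demo)

# auc integrates the curve with the trapezoidal rule.
auc_demo = auc(fpr_demo, tpr_demo)
print(fpr_demo, tpr_demo, auc_demo)
```

The thresholds are useful when you want to pick an operating point, but the plot itself only needs the fpr and tpr arrays.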
classes_quantity = 3
fpr = [0, 0, 0]
tpr = [0, 0, 0]
auc_score = [0, 0, 0]
for i in range(classes_quantity):
    fpr[i], tpr[i], _ = roc_curve(test_y[:, i], y_prob[:, i])
    auc_score[i] = auc(fpr[i], tpr[i])
print(auc_score)
The above code will print out the list containing the AUC scores of each class according to the OVA method.
[1.0, 0.8038277511961722, 1.0]
Take the average of these per-class scores to get the overall AUC score.
print(sum(auc_score) / classes_quantity)
Output:
0.934609250398724
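The same kind of average can also be obtained in one call with scikit-learn's `roc_auc_score` and `average="macro"`, which takes the unweighted mean of the per-class AUCs. The sketch below rebuilds the pipeline so that it runs on its own; `max_iter=1000` is my own assumption (to avoid possible convergence warnings), so the printed number may differ slightly from the one above.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

X, y = datasets.load_iris(return_X_y=True)
y_bin = label_binarize(y, classes=[0, 1, 2])
train_X, test_X, train_y, test_y = train_test_split(
    X, y_bin, test_size=0.2, random_state=2023
)

# max_iter=1000 is an assumption made for this sketch only.
model = OneVsRestClassifier(LogisticRegression(max_iter=1000))
model.fit(train_X, train_y)
y_prob = model.predict_proba(test_X)

# Manual average of the per-class AUC scores.
per_class = []
for i in range(3):
    fpr_i, tpr_i, _ = roc_curve(test_y[:, i], y_prob[:, i])
    per_class.append(auc(fpr_i, tpr_i))
manual_macro = float(np.mean(per_class))

# Built-in macro average over the binarized label columns.
builtin_macro = roc_auc_score(test_y, y_prob, average="macro")
print(manual_macro, builtin_macro)
```

Both numbers should agree, which is a handy sanity check on the manual loop.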
Step 6: Plotting the ROC Curve
Since we have 3 classes, there will be 3 ROC curves, one per class. I am choosing red, orange, and blue for classes 0, 1, and 2, respectively. I am using the zip function to iterate over two things simultaneously: the class indices and the colors list.
plt.figure(figsize=(8, 8))
colors = ['red', 'orange', 'blue']
for i, color in zip(range(classes_quantity), colors):
    plt.plot(fpr[i], tpr[i], color=color, label=f'Class {i} (AUC = {auc_score[i]:.2f})')
plt.plot([0, 1], [0, 1], color="black", linestyle="--")
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Analysis for Multi-Class Classification')
plt.legend(loc="lower right")
plt.show()
Output:
The red curve can’t be seen because the curves for class 0 (red) and class 2 (blue) overlap perfectly: both have an AUC of 1.0. Since the blue curve is plotted after the red one, it is drawn on top and is the only one of the two that is visible.
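If you want both overlapping curves to stay visible, one option is to vary the line width, transparency, and line style. The sketch below is only an illustration with two hand-written "perfect" curves (hypothetical values, not the iris results). It uses the non-interactive Agg backend so that it runs without a display; in an interactive session you can drop that line and call plt.show() instead.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so the sketch runs anywhere
import matplotlib.pyplot as plt

# Two identical "perfect" ROC curves (hypothetical values for illustration).
perfect_fpr = [0.0, 0.0, 1.0]
perfect_tpr = [0.0, 1.0, 1.0]

fig, ax = plt.subplots(figsize=(6, 6))
# Wide, semi-transparent line underneath.
ax.plot(perfect_fpr, perfect_tpr, color="red", linewidth=5, alpha=0.4, label="Class 0")
# Thin dashed line on top, so the red line shows through the gaps.
ax.plot(perfect_fpr, perfect_tpr, color="blue", linewidth=1.5, linestyle="--", label="Class 2")
# Chance-level diagonal.
ax.plot([0, 1], [0, 1], color="black", linestyle=":")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(loc="lower right")
```

The legend still lists both classes, and the wide translucent red line remains visible around and between the dashes of the blue one.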