How to Plot a Validation Curve in Machine Learning with Python?


Validation Curve in Machine Learning –

Many algorithms in machine learning have hyperparameters that must be chosen before training begins. For example, one of the most important hyperparameters of a random forest is the number of trees in the forest. Most often these hyperparameter values are selected during model selection, but it is occasionally useful to visualize how model performance changes as we vary one of them.

In scikit-learn, we can compute a validation curve using validation_curve, which takes three important parameters (a minimal example of the call is sketched right after this list).

  1. param_name – the name of the hyperparameter to vary.
  2. param_range – the range of values of the hyperparameter to evaluate.
  3. scoring – the evaluation metric used to judge the model.
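
For a quick feel for the raw call, here is a minimal sketch. It uses scikit-learn's built-in breast cancer dataset and some arbitrary settings just to stay self-contained; the rest of the post works with a CSV file instead.

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# Built-in dataset used only to keep this sketch self-contained
X, y = load_breast_cancer(return_X_y=True)

train_scores, test_scores = validation_curve(
    estimator=RandomForestClassifier(random_state=42),
    X=X,
    y=y,
    param_name="max_depth",       # hyperparameter to vary
    param_range=np.arange(1, 6),  # values to evaluate
    scoring="accuracy",           # evaluation metric
    cv=3,
)

# Both arrays have shape (len(param_range), number of CV folds)
print(train_scores.shape, test_scores.shape)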

How to plot a validation curve in Python?

Let’s read a dataset to work with.

import pandas as pd
import numpy as np

# Read data 
cancer = pd.read_csv('https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv')
cancer.head()

Now split the data into training and test set.

# split the data into training and test set
from sklearn.model_selection import train_test_split

X = cancer.drop('diagnosis', axis=1).copy()
y = cancer['diagnosis']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Instantiate a random forest classifier. You can use any other algorithm that you like.

from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier(random_state=42)

Function for plotting validation curves.

from sklearn.model_selection import validation_curve
import plotly.graph_objects as go 

def plot_validation_curves(estimator, X, y, param_name, param_range, cv, scoring="accuracy"):
    # Compute cross-validated training and validation scores
    # for each value of the hyperparameter
    train_scores, test_scores = validation_curve(
        estimator=estimator,
        X=X,
        y=y,
        param_name=param_name,
        param_range=param_range,
        cv=cv,
        scoring=scoring,
    )
    # Average the scores across the cross-validation folds
    train_mean = np.mean(train_scores, axis=1)
    test_mean = np.mean(test_scores, axis=1)

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=param_range,
            y=train_mean,
            name="Training Accuracy",
            mode="lines",
            line=dict(color="Blue"),
        )
    )

    fig.add_trace(
        go.Scatter(
            x=param_range,
            y=test_mean,
            name="Validation Accuracy",
            mode="lines",
            line=dict(color="Green"),
        )
    )

    fig.update_layout(
        title="Validation Curves", xaxis_title=param_name, yaxis_title="Accuracy"
    )

    fig.show()

Now, let’s plot the validation curve for the tree depth (max_depth).

param_range = np.arange(3, 30, 3)
plot_validation_curves(clf, X_train, y_train, "max_depth", param_range, 5)

We can see that the model performs best when the maximum depth of the trees is set to 6, since the validation accuracy is highest at that depth. As we increase the depth further, the accuracy decreases slightly and then remains roughly flat.
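
If we want to lock in that value, a minimal follow-up step (assuming max_depth=6 as read off the curve above, and reusing the train/test split from earlier) could look like this:

# Refit on the training set with the depth suggested by the validation curve
# (max_depth=6 is an assumption based on the plot above)
final_clf = RandomForestClassifier(max_depth=6, random_state=42)
final_clf.fit(X_train, y_train)

# Evaluate once on the held-out test set
print("Test accuracy:", final_clf.score(X_test, y_test))

The same plot_validation_curves helper can also be reused for other hyperparameters, such as n_estimators, the number of trees mentioned at the start of this post.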
