Validation Curve in Machine Learning
Many algorithms in machine learning contain hyperparameters that must be chosen before the training process begins. For example, one of the most important hyperparameters in a random forest is the number of trees in the forest. Most often, these hyperparameter values are selected during model selection, but it is occasionally useful to visualize how model performance changes as we vary them.
In scikit-learn, we can compute a validation curve using the validation_curve function, which takes three important parameters:
- param_name – the name of the hyperparameter to vary.
- param_range – the range of hyperparameter values to try.
- scoring – the evaluation metric used to judge the model.
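Before plotting anything, it helps to see what validation_curve actually returns: two arrays of cross-validation scores, one row per hyperparameter value and one column per fold. The sketch below uses a small synthetic dataset from make_classification (an assumption for illustration, not the breast-cancer data used later in this article):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# Small synthetic dataset, just to illustrate the return values
X, y = make_classification(n_samples=200, random_state=42)

param_range = [2, 4, 6]
train_scores, test_scores = validation_curve(
    estimator=RandomForestClassifier(n_estimators=10, random_state=42),
    X=X,
    y=y,
    param_name="max_depth",   # hyperparameter to vary
    param_range=param_range,  # values to try
    scoring="accuracy",       # evaluation metric
    cv=3,                     # 3-fold cross-validation
)

# One row per hyperparameter value, one column per CV fold
print(train_scores.shape)  # (3, 3)
print(test_scores.shape)   # (3, 3)
```

Averaging each row (axis=1) then gives one training score and one validation score per hyperparameter value, which is exactly what gets plotted.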
How to plot a validation curve in Python?
Let’s read a dataset to work with.
```python
import pandas as pd
import numpy as np

# Read data
cancer = pd.read_csv(
    "https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv"
)
cancer.head()
```
Now split the data into training and test set.
```python
# split the data into training and test set
from sklearn.model_selection import train_test_split

X = cancer.drop('diagnosis', axis=1).copy()
y = cancer['diagnosis']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
```
Initialize a random forest classifier. You can use any other algorithm you like.
```python
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(random_state=42)
```
Function for plotting validation curves.
```python
from sklearn.model_selection import validation_curve
import plotly.graph_objects as go


def plot_validation_curves(estimator, X, y, param_name, param_range, cv):
    train_scores, test_scores = validation_curve(
        estimator=estimator,
        X=X,
        y=y,
        param_name=param_name,
        param_range=param_range,
        cv=cv,
    )
    # Average the scores across the cross-validation folds
    train_mean = np.mean(train_scores, axis=1)
    test_mean = np.mean(test_scores, axis=1)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=param_range,
            y=train_mean,
            name="Training Accuracy",
            mode="lines",
            line=dict(color="Blue"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=param_range,
            y=test_mean,
            name="Validation Accuracy",
            mode="lines",
            line=dict(color="Green"),
        )
    )
    fig.update_layout(
        title="Validation Curves", xaxis_title=param_name, yaxis_title="Accuracy"
    )
    fig.show()
```
Now, let's plot the validation curve for the max_depth hyperparameter.
```python
param_range = np.arange(3, 30, 3)
plot_validation_curves(clf, X_train, y_train, "max_depth", param_range, 5)
```
We can see that the model performs best when the maximum depth of the tree is set to 6, as the validation accuracy peaks at this depth. As we increase the depth further, the accuracy starts to decrease and then remains flat.
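If you want to extract the best value programmatically rather than reading it off the plot, you can take the argmax of the mean validation score. A minimal sketch, again using a synthetic dataset rather than the article's breast-cancer CSV (which would require a download):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

X, y = make_classification(n_samples=300, random_state=42)

param_range = np.arange(3, 30, 3)
train_scores, test_scores = validation_curve(
    RandomForestClassifier(n_estimators=25, random_state=42),
    X,
    y,
    param_name="max_depth",
    param_range=param_range,
    cv=5,
)

# Mean validation accuracy per depth, then the depth that maximizes it
test_mean = test_scores.mean(axis=1)
best_depth = param_range[np.argmax(test_mean)]
print("Best max_depth:", best_depth)
```

On your own data the winning depth will differ, but this is a convenient complement to the visual inspection above.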