How to Plot a Learning Curve in Machine Learning with Python

Spread the love

Learning Curves in Machine Learning

A learning curve visualizes the performance (e.g., accuracy or recall) of a model on the training set and during cross-validation as the number of observations in the training set increases. Learning curves are commonly used to determine whether a learning algorithm would benefit from gathering additional data.

How to plot a learning curve?

Let’s read a dataset to work with.

import pandas as pd
import numpy as np

# Load the breast cancer dataset from GitHub and preview its first five rows
# (the preview is only displayed when run in a notebook or REPL).
cancer = pd.read_csv(
    "https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv"
)
cancer.head(n=5)

Now, let’s split the data into a training set and a test set.

# Use "diagnosis" as the target and every other column as a feature,
# then hold out 20% of the rows as a reproducible test set.
from sklearn.model_selection import train_test_split

y = cancer["diagnosis"]
X = cancer.drop(columns="diagnosis").copy()

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

Next, we will create a training pipeline with data preprocessing.

# Build a single estimator that imputes missing values, standardizes the
# features, and then fits a seeded random forest.
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# NOTE(review): tree ensembles are generally insensitive to feature scaling,
# so StandardScaler is probably unnecessary here; kept to match the tutorial.
rf = make_pipeline(
    SimpleImputer(),
    StandardScaler(),
    RandomForestClassifier(random_state=42),
)

Now, let’s write a function to plot the learning curve.

Note – Don’t forget to change the scoring parameter and the plot labels to match the metric you are using.

# function for plotting learning curve
from sklearn.model_selection import learning_curve
import plotly.graph_objects as go
import numpy as np

def plot_learning_curves(
    estimator,
    X,
    y,
    cv,
    scoring="accuracy",
    metric_label="Accuracy",
    axis_label=None,
    train_sizes=None,
):
    """Plot mean training and cross-validation scores against training-set size.

    Runs ``sklearn.model_selection.learning_curve`` and draws the
    fold-averaged training and validation scores as two plotly line traces.

    Parameters
    ----------
    estimator : object
        Model or pipeline to evaluate.
    X, y : array-like
        Features and target, passed straight to ``learning_curve``.
    cv : int or cross-validation splitter
        Cross-validation strategy passed to ``learning_curve``.
    scoring : str or callable, default "accuracy"
        Scorer passed to ``learning_curve``. If it is a string starting with
        ``"neg_"`` (e.g. ``"neg_log_loss"``), the scores are negated before
        plotting so the curve shows the positive metric.
    metric_label : str, default "Accuracy"
        Metric name used in the trace names ("Training <label>" and
        "Validation <label>"). Change it together with ``scoring``.
    axis_label : str, optional
        Y-axis title; defaults to ``metric_label``.
    train_sizes : array-like, optional
        Passed to ``learning_curve``; defaults to ten evenly spaced
        fractions from 0.1 to 1.0.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()``.
    """
    if train_sizes is None:
        train_sizes = np.linspace(0.1, 1.0, 10)
    if axis_label is None:
        axis_label = metric_label

    # learning_curve returns the training-set sizes actually used, which are
    # the x-coordinates of the plot.
    sizes_used, train_scores, test_scores = learning_curve(
        estimator=estimator,
        X=X,
        y=y,
        train_sizes=train_sizes,
        cv=cv,
        scoring=scoring,
        # NOTE: per the scikit-learn docs, random_state only takes effect when
        # shuffle=True (the default is False); kept as in the original.
        random_state=42,
    )

    # scikit-learn scorers follow "greater is better", so loss-style scorers
    # are reported negated ("neg_*"); flip them back for a readable plot.
    negated = isinstance(scoring, str) and scoring.startswith("neg_")
    sign = -1.0 if negated else 1.0
    train_mean = np.mean(sign * train_scores, axis=1)
    test_mean = np.mean(sign * test_scores, axis=1)

    fig = go.Figure()

    # One line per curve: training first, then cross-validation.
    for scores, prefix, color in (
        (train_mean, "Training", "blue"),
        (test_mean, "Validation", "green"),
    ):
        fig.add_trace(
            go.Scatter(
                x=sizes_used,
                y=scores,
                name=f"{prefix} {metric_label}",
                mode="lines",
                line=dict(color=color),
            )
        )

    fig.update_layout(
        title="Learning Curves",
        xaxis_title="Number of training examples",
        yaxis_title=axis_label,
    )

    fig.show()

Now let’s plot the learning curve.

# Draw the learning curve of the pipeline on the training split with 5-fold CV.
plot_learning_curves(estimator=rf, X=X_train, y=y_train, cv=5)

We can see that the validation accuracy keeps increasing as the training size grows, so gathering more training samples would likely be beneficial.

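Here is a function for plotting the learning curve for a regression problem. Because scikit-learn scorers follow a "greater is better" convention, the mean absolute error is returned as a negative score ("neg_mean_absolute_error"), so it is negated before plotting.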

def plot_learning_curves(
    estimator,
    X,
    y,
    cv,
    scoring="neg_mean_absolute_error",
    metric_label="MAE",
    axis_label="Mean Absolute Error",
    train_sizes=None,
):
    """Plot mean training and cross-validation errors against training-set size.

    Regression variant: runs ``sklearn.model_selection.learning_curve`` and
    draws the fold-averaged training and validation scores as two plotly
    line traces.

    Parameters
    ----------
    estimator : object
        Model or pipeline to evaluate.
    X, y : array-like
        Features and target, passed straight to ``learning_curve``.
    cv : int or cross-validation splitter
        Cross-validation strategy passed to ``learning_curve``.
    scoring : str or callable, default "neg_mean_absolute_error"
        Scorer passed to ``learning_curve``. If it is a string starting with
        ``"neg_"``, the scores are negated before plotting so the curve shows
        the positive error.
    metric_label : str, default "MAE"
        Metric name used in the trace names ("Training <label>" and
        "Validation <label>"). Change it together with ``scoring``.
    axis_label : str or None, default "Mean Absolute Error"
        Y-axis title; ``None`` falls back to ``metric_label``.
    train_sizes : array-like, optional
        Passed to ``learning_curve``; defaults to ten evenly spaced
        fractions from 0.1 to 1.0.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()``.
    """
    if train_sizes is None:
        train_sizes = np.linspace(0.1, 1.0, 10)
    if axis_label is None:
        axis_label = metric_label

    # learning_curve returns the training-set sizes actually used, which are
    # the x-coordinates of the plot.
    sizes_used, train_scores, test_scores = learning_curve(
        estimator=estimator,
        X=X,
        y=y,
        train_sizes=train_sizes,
        cv=cv,
        scoring=scoring,
        # NOTE: per the scikit-learn docs, random_state only takes effect when
        # shuffle=True (the default is False); kept as in the original.
        random_state=42,
    )

    # Error scorers are reported negated ("neg_*") because scikit-learn
    # scorers follow "greater is better"; flip them back to positive errors.
    negated = isinstance(scoring, str) and scoring.startswith("neg_")
    sign = -1.0 if negated else 1.0
    train_mean = np.mean(sign * train_scores, axis=1)
    test_mean = np.mean(sign * test_scores, axis=1)

    fig = go.Figure()

    # One line per curve: training first, then cross-validation.
    for scores, prefix, color in (
        (train_mean, "Training", "blue"),
        (test_mean, "Validation", "green"),
    ):
        fig.add_trace(
            go.Scatter(
                x=sizes_used,
                y=scores,
                name=f"{prefix} {metric_label}",
                mode="lines",
                line=dict(color=color),
            )
        )

    fig.update_layout(
        title="Learning Curves",
        xaxis_title="Number of training examples",
        yaxis_title=axis_label,
    )

    fig.show()

Rating: 1 out of 5.

Leave a Reply