How to Visualize a Decision Tree Model?

In this post, you will learn how to visualize a decision tree model.

Visualize a Decision Tree Model

First, let’s read in a dataset to work with.

# import libraries
import pandas as pd
import numpy as np

# read a dataset
url = 'https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv'
df = pd.read_csv(url)
df.head()
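
If the URL above is unreachable, a roughly equivalent DataFrame can be built offline from scikit-learn's bundled copy of the Breast Cancer Wisconsin (Diagnostic) dataset. This is a sketch under assumptions: it presumes the CSV comes from that same dataset, its feature column names will likely differ from the CSV's, and the 'B'/'M' labels are mapped by hand from scikit-learn's encoding (0 = malignant, 1 = benign).

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

# load the bundled dataset as a pandas DataFrame (30 feature columns + a 'target' column)
data = load_breast_cancer(as_frame=True)
df_offline = data.frame.copy()

# assumption: recreate a 'diagnosis' column with the same 'B'/'M' labels as the CSV;
# in scikit-learn's encoding, 0 = malignant and 1 = benign
df_offline['diagnosis'] = df_offline.pop('target').map({0: 'M', 1: 'B'})

print(df_offline.shape)                       # (569, 31)
print(df_offline['diagnosis'].value_counts())
```

Assigning `df = df_offline` should let the remaining steps run unchanged, since they only rely on the `diagnosis` column and the 'B'/'M' labels.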

Next, split the data into training and test sets.

from sklearn.model_selection import train_test_split

X = df.drop('diagnosis', axis=1)
y = df['diagnosis']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
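
The breast cancer data has more benign than malignant cases (357 vs 212 in scikit-learn's bundled copy), so passing `stratify=y` to `train_test_split` is one way to keep that ratio about the same in both splits. This is an optional variant, not the post's original call; the sketch below uses the bundled dataset as a stand-in for the CSV (an assumption) and deliberately different variable names so it does not overwrite `X_train` and friends.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# assumption: scikit-learn's bundled copy stands in for the CSV used in the post
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# stratify=y preserves the class proportions of y in both the train and test sets
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(len(X_tr), len(X_te))  # 455 114
# fraction of benign (label 1) samples overall, in train, and in test
print(round(y.mean(), 3), round(y_tr.mean(), 3), round(y_te.mean(), 3))
```

With the post's own `X` and `y`, the only change would be adding `stratify=y` to the existing `train_test_split` call.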

Now, let’s train a decision tree classifier and plot it.

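from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# create a decision tree classifier, limited to depth 3 so the plot stays readable
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
# train it on the training data
clf.fit(X_train, y_train)

# plot the decision tree
# plot_tree draws with matplotlib and returns a list of text annotations (not DOT source)
plt.figure(figsize=(10, 3), dpi=300)
plot_tree(clf,
          feature_names=X_train.columns.tolist(),
          # class_names must follow the sorted order of clf.classes_, so derive it from there
          class_names=[str(c) for c in clf.classes_],
          filled=True)
# save before show(), since show() may leave an empty current figure behind
plt.savefig('decision_tree.png', bbox_inches='tight')
plt.show()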
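
`plot_tree` produces a matplotlib image. If a text or Graphviz version of the same tree is more convenient (for logs, diffs, or rendering with the `dot` tool), scikit-learn also provides `export_text` and `export_graphviz`. The sketch below fits a comparable depth-3 tree on scikit-learn's bundled copy of the dataset, an assumption made so it runs without the CSV; with the post's `clf` and `X_train.columns.tolist()`, the two export calls should work the same way.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text

# assumption: the bundled dataset stands in for the CSV; feature names differ from the post
data = load_breast_cancer()
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(data.data, data.target)

# plain-text rules, one line per split or leaf
rules = export_text(tree, feature_names=list(data.feature_names))
print(rules)

# Graphviz DOT source; out_file=None returns it as a string instead of writing a file
dot_source = export_graphviz(
    tree,
    out_file=None,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),  # ['malignant', 'benign'] for targets 0 and 1
    filled=True,
)
print(dot_source[:200])
```

After saving the DOT string to a file such as `tree.dot`, it can be rendered with the Graphviz command-line tool, for example `dot -Tpng tree.dot -o tree.png`.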
