
In this post, you will learn how to visualize a decision tree model.
Visualize a Decision Tree Model
First, let's read in a dataset to work with.
# import libraries
import pandas as pd
import numpy as np
# read a dataset
url = 'https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv'
df = pd.read_csv(url)
df.head()
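If the URL above is unreachable, you can build a similar table offline. The sketch below assumes that scikit-learn's bundled `load_breast_cancer` data (the Wisconsin diagnostic dataset) is an acceptable stand-in for the CSV. Its feature column names differ from the CSV's, and mapping the numeric target to the 'M'/'B' labels is our own choice, not something taken from the original file.

```python
from sklearn.datasets import load_breast_cancer

# Load the bundled Wisconsin breast cancer data as pandas objects
data = load_breast_cancer(as_frame=True)
df = data.frame.copy()  # 30 numeric feature columns plus a numeric 'target' column

# scikit-learn encodes 0 = malignant and 1 = benign; map these to the
# 'M'/'B' labels used in this post (assumed to match the CSV's 'diagnosis')
df['diagnosis'] = df['target'].map({0: 'M', 1: 'B'})
df = df.drop(columns='target')

print(df.shape)
print(df['diagnosis'].value_counts())
```

The rest of the post only needs a DataFrame with a 'diagnosis' column, so this `df` can be used in place of the downloaded one.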

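Next, split the data into training and test sets.
from sklearn.model_selection import train_test_split
# the features are every column except the target column 'diagnosis'
X = df.drop('diagnosis', axis=1)
y = df['diagnosis']
# hold out 20% of the rows as a test set; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)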
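Because the classes are imbalanced (there are more benign than malignant cases), you may want the training and test sets to keep the same class ratio. The sketch below passes `stratify=y` to `train_test_split` to do that. It uses the bundled scikit-learn copy of the data as an assumed stand-in for the CSV so that it runs offline.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Offline stand-in for the CSV: features in X, 'M'/'B' labels in y
data = load_breast_cancer(as_frame=True)
X = data.data
y = data.target.map({0: 'M', 1: 'B'})

# stratify=y keeps the benign/malignant ratio (roughly) equal in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(len(X_train), len(X_test))
print(y_train.value_counts(normalize=True).round(3))
print(y_test.value_counts(normalize=True).round(3))
```

Without `stratify`, a random 20% split can over- or under-represent the minority class, and that changes the class proportions shown in the plotted tree's nodes.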
Now, let's train a decision tree classifier and plot it.
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# create a Decision tree classifier
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
# train it on the training data
clf.fit(X_train, y_train)
# plot the decision tree
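plt.figure(figsize=(10, 3), dpi=300)
# plot_tree draws on the current axes and returns a list of text annotations (not DOT source)
annotations = plot_tree(clf,
                        feature_names=X_train.columns.tolist(),
                        # class names must follow the order of clf.classes_ (here 'B', 'M')
                        class_names=[str(c) for c in clf.classes_],
                        filled=True)
plt.savefig('decision_tree.png')
plt.show()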
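If you don't need an image, scikit-learn can also describe the same tree as text. The sketch below uses `export_text` to print the split rules and `export_graphviz` with `out_file=None` to get Graphviz DOT source as a string, which the Graphviz tools can render later. For brevity, it assumes the bundled scikit-learn copy of the data and fits on all rows instead of the training split.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text

# Offline stand-in for the CSV, with 'M'/'B' string labels
data = load_breast_cancer(as_frame=True)
X = data.data
y = data.target.map({0: 'M', 1: 'B'})

clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X, y)

feature_names = X.columns.tolist()
class_names = [str(c) for c in clf.classes_]  # sorted order used by the model

# Plain-text view of the split rules, one line per node
rules = export_text(clf, feature_names=feature_names)
print(rules)

# Graphviz DOT source; with out_file=None the function returns a string
dot_source = export_graphviz(
    clf,
    out_file=None,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
)
print(dot_source[:40])
```

The text output is handy for logging or diffing two models, while the DOT string lets you restyle the tree outside matplotlib.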
