Identifying Important Features in Random Forest
One of the major benefits of decision trees is interpretability: we can visualize the entire model. However, a random forest is made up of tens, hundreds, or even thousands of decision trees, which makes a simple, intuitive visualization of the whole model impractical. That said, there is another option: we can compare and visualize the relative importance of each feature.
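Feature importance sidesteps the "too many trees" problem because it aggregates over the whole forest. The minimal sketch below checks this, assuming scikit-learn's impurity-based importances. It uses scikit-learn's built-in `load_breast_cancer` data as an offline stand-in for the CSV used later (an assumption, not the post's dataset). It averages each tree's own `feature_importances_` and renormalizes, which should reproduce the forest's `feature_importances_`.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Built-in breast cancer data, used here only so the sketch runs offline
X, y = load_breast_cancer(return_X_y=True)

rf = RandomForestClassifier(n_estimators=50, random_state=0)
rf.fit(X, y)

# Each fitted tree has its own normalized impurity-based importances;
# trees that are a single leaf carry no split information, so skip them
per_tree = np.array(
    [tree.feature_importances_ for tree in rf.estimators_ if tree.tree_.node_count > 1]
)

# Average over trees, then renormalize so the values sum to 1
averaged = per_tree.mean(axis=0)
averaged /= averaged.sum()

print(np.allclose(averaged, rf.feature_importances_))
```

A single vector of 30 numbers then summarizes all 50 trees, which is what makes it practical to plot.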
Let's see how to do it. First, read in a dataset to work with.
```python
import pandas as pd
import numpy as np

url = 'https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv'
df = pd.read_csv(url)
df.head()
```
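If the CSV above cannot be downloaded, a similar breast cancer dataset ships with scikit-learn. The sketch below loads it as a DataFrame and renames its label column to `diagnosis` so the rest of the code runs unchanged. This is an assumed substitute: its column names (e.g. `mean radius`) and label encoding (0 = malignant, 1 = benign) may not match the CSV.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

# as_frame=True returns the features and target together in data.frame
data = load_breast_cancer(as_frame=True)

# Rename the built-in "target" column to match the label name used in this post
df = data.frame.rename(columns={"target": "diagnosis"})

print(df.shape)
print(df.head())
```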
Split the data into a training and test set.
```python
from sklearn.model_selection import train_test_split

X = df.drop('diagnosis', axis=1)
y = df['diagnosis']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
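Because the two diagnoses are not equally common, you may want both splits to keep the same class ratio. The variant below is a sketch of that idea using `train_test_split`'s `stratify` argument. It again assumes the built-in scikit-learn copy of the data so that it runs on its own.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Features as a DataFrame, labels as a Series
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# stratify=y keeps the proportion of each class the same in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Share of the positive class in the full data, training split, and test split
print(y.mean(), y_train.mean(), y_test.mean())
```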
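Now, train a random forest model and visualize its feature importances. In scikit-learn, a fitted forest exposes these through the `feature_importances_` attribute: each feature's total decrease in impurity, averaged over all trees and normalized so that the values sum to 1.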
```python
from sklearn.ensemble import RandomForestClassifier
import plotly.graph_objects as go

# create a random forest classifier object
rf = RandomForestClassifier()

# train a model
rf.fit(X_train, y_train)

# calculate feature importances
importances = rf.feature_importances_

# sort feature indices by importance, in descending order
indices = np.argsort(importances)[::-1]

# create a list of all columns
all_features = X_train.columns.tolist()

# collect importances and feature names in sorted order
importances_list = []
feat_labels_list = []
for i in range(X_train.shape[1]):
    importances_list.append(importances[indices[i]])
    feat_labels_list.append(all_features[indices[i]])

# create the feature importance plot (horizontal bars, most important on top)
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=importances_list,
        y=feat_labels_list,
        orientation="h",
        marker_color="#329932",
    )
)
fig.update_layout(
    yaxis=dict(title="Features", autorange="reversed"),
    xaxis=dict(title="Random Forest Feature Importance"),
)
fig.show()
```
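Impurity-based importances are computed from the training data, and scikit-learn's documentation notes that they can favor features with many distinct values. A common cross-check is a different technique, permutation importance. It shuffles one feature at a time on held-out data and measures how much the model's score drops. The sketch below uses `sklearn.inspection.permutation_importance` on the built-in copy of the dataset (an assumed stand-in for the CSV). `n_repeats=10` is an arbitrary choice.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

rf = RandomForestClassifier(random_state=42).fit(X_train, y_train)

# Shuffle each feature 10 times on the test set and record the drop in accuracy
result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42)

# Mean drop per feature, largest first
perm = pd.Series(result.importances_mean, index=X.columns).sort_values(ascending=False)
print(perm.head(10))
```

If a feature ranks high in the impurity-based plot but near zero here, the forest may be relying on it less than the first plot suggests.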