Extra-Trees in Machine Learning


Extra-Trees

When you are growing a tree in a Random Forest, at each node only a random subset of the features is considered for splitting. It is possible to make trees even more random by also using random thresholds for each feature, rather than searching for the best possible threshold (as regular Decision Trees do).

A forest of such extremely random trees is called an Extremely Randomized Trees ensemble (or Extra-Trees for short). This technique trades more bias for lower variance. It also makes Extra-Trees much faster to train than regular Random Forests, because finding the best possible threshold for each feature at every node is one of the most time-consuming tasks of growing a tree.
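To get a feel for that speed difference, here is a minimal sketch (not part of the original example) that times both ensembles with the same settings on a synthetic dataset; the make_classification data and the sizes chosen are assumptions purely for illustration:

import time

from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier

# Synthetic data, used only for a rough timing comparison
X_demo, y_demo = make_classification(n_samples=5000, n_features=50, random_state=42)

for Model in (RandomForestClassifier, ExtraTreesClassifier):
    clf = Model(n_estimators=100, random_state=42)
    start = time.time()
    clf.fit(X_demo, y_demo)
    print(f"{Model.__name__}: {time.time() - start:.2f}s")

On most machines the Extra-Trees ensemble fits noticeably faster, since it skips the per-feature threshold search at each node.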

You can create an Extra-Trees classifier using Scikit-Learn's ExtraTreesClassifier class. Its API is identical to that of the RandomForestClassifier class. Similarly, the ExtraTreesRegressor class has the same API as the RandomForestRegressor class.
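Because the APIs match, an ExtraTreesRegressor accepts the same hyperparameters you would pass to a RandomForestRegressor (n_estimators, max_features, max_depth, and so on). The snippet below is a sketch on synthetic regression data; make_regression and the parameter values are assumptions for demonstration:

from sklearn.datasets import make_regression
from sklearn.ensemble import ExtraTreesRegressor

# Synthetic regression data, for demonstration only
X_reg, y_reg = make_regression(n_samples=1000, n_features=10, noise=10.0, random_state=42)

# Same hyperparameters a RandomForestRegressor would accept
extra_reg = ExtraTreesRegressor(n_estimators=100, max_features="sqrt", random_state=42)
extra_reg.fit(X_reg, y_reg)
print(extra_reg.score(X_reg, y_reg))  # R^2 score on the training data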

How to Train an Extra-Trees Model in Scikit-Learn?

Let’s read a dataset to work with.

import pandas as pd
import numpy as np

url = 'https://raw.githubusercontent.com/bprasad26/lwd/master/data/breast_cancer.csv'
df = pd.read_csv(url)
df.head()

Next, split the data into training and test sets.

from sklearn.model_selection import train_test_split

X = df.drop('diagnosis', axis=1)
y = df['diagnosis']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Now train an Extra-Trees classifier in scikit-learn.

from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score

extra_clf = ExtraTreesClassifier()
extra_clf.fit(X_train, y_train)
y_pred = extra_clf.predict(X_test)
accuracy_score(y_test, y_pred)
# output
0.9736842105263158
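As a follow-up, a fitted Extra-Trees model exposes the same feature_importances_ attribute as a Random Forest, so you can quickly check which columns drive the predictions. This sketch reuses the extra_clf classifier and X features from above:

import pandas as pd

# Feature importances of the fitted classifier, largest first
importances = pd.Series(extra_clf.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False).head())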
