Thursday, February 27, 2025

Decision Trees for Classification

Decision trees are one of the most popular machine learning algorithms for both classification and regression tasks. In this post, we will focus on Decision Trees for classification tasks, where the goal is to predict the class label of an object based on its features.

What is a Decision Tree?

A Decision Tree is a flowchart-like structure where each internal node represents a "test" or "decision" on an attribute (e.g., whether a feature is greater than a threshold value), each branch represents the outcome of that decision, and each leaf node represents a class label (the outcome of the classification). The tree is constructed by splitting the data at each internal node based on the most important features, with the goal of classifying the data into distinct classes.
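
To make this structure concrete, the sketch below is a hypothetical, hand-written rule set (not a trained model; the thresholds are purely illustrative) showing what a small decision tree for iris data amounts to once learned: a chain of threshold tests ending in a class label.
def toy_iris_tree(petal_length, petal_width):
    # Internal node: test petal length against a threshold
    if petal_length <= 2.45:
        return "setosa"        # leaf node: class label
    # Internal node: test petal width against a threshold
    elif petal_width <= 1.75:
        return "versicolor"    # leaf node: class label
    else:
        return "virginica"     # leaf node: class label

print(toy_iris_tree(petal_length=1.4, petal_width=0.2))  # prints "setosa"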

Key Features of Decision Trees:

  • Interpretability: One of the biggest advantages of decision trees is that they are easy to interpret and visualize.
  • Non-linear decision boundaries: Unlike linear models, decision trees can handle non-linear decision boundaries.
  • Handles both numerical and categorical data: Decision trees can handle both types of data without needing feature scaling.
  • Overfitting risk: Decision trees are prone to overfitting, especially when the tree is deep, meaning it has too many branches.

Advantages and Disadvantages of Decision Trees

Advantages:

  • Simple to understand and interpret.
  • Can handle both categorical and numerical data.
  • Requires little data preprocessing; normalization and feature scaling are typically unnecessary.
  • Can handle missing values and can still make a classification based on the remaining features.

Disadvantages:

  • Prone to overfitting, especially with deep trees.
  • Not robust to small changes in the data; a slightly different training set can produce a very different tree (see the sketch after this list).
  • Can be biased towards features with more levels (many categories in categorical features).
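
To illustrate this instability concretely, here is a minimal standalone sketch using the same Iris data introduced later in this post: two trees trained on sets that differ by only ten rows can end up with different depths and node counts.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Load the Iris data (also used later in this post)
X, y = load_iris(return_X_y=True)

# Train two trees whose training sets differ by only ten rows
tree_a = DecisionTreeClassifier(random_state=0).fit(X[:140], y[:140])
tree_b = DecisionTreeClassifier(random_state=0).fit(X[10:], y[10:])

# Even this small change can produce trees of different shape
print(tree_a.get_depth(), tree_a.tree_.node_count)
print(tree_b.get_depth(), tree_b.tree_.node_count)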

Decision Tree Classifier in Python

In this section, we will implement a Decision Tree classifier using Scikit-learn and demonstrate how to classify data.

1. Load the Dataset

For this example, we will use the famous Iris dataset, which contains four measurements (sepal length, sepal width, petal length, and petal width) for three species of iris flowers (Setosa, Versicolor, and Virginica). The goal is to classify each flower's species from these measurements.
from sklearn.datasets import load_iris
import pandas as pd

# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Create a DataFrame for better visualization
df = pd.DataFrame(X, columns=iris.feature_names)
df['target'] = y
df.head()
    
The previous code block consists of the following code lines:
  • Import necessary libraries:
    • from sklearn.datasets import load_iris - Imports the load_iris function from the sklearn.datasets module to load the Iris dataset.
    • import pandas as pd - Imports the pandas library, which is useful for data manipulation and visualization.
  • Load the Iris dataset:
    • iris = load_iris() - Loads the Iris dataset into the variable iris. This dataset contains data about the features and species of Iris flowers.
    • X = iris.data - Extracts the feature matrix (sepal length, sepal width, petal length, and petal width) from the dataset and stores it in X.
    • y = iris.target - Extracts the target labels (the species of the Iris flowers) from the dataset and stores it in y.
  • Create a DataFrame for better visualization:
    • df = pd.DataFrame(X, columns=iris.feature_names) - Creates a pandas DataFrame from the feature matrix X and labels the columns using the feature names from the Iris dataset.
    • df['target'] = y - Adds a new column named 'target' to the DataFrame df, containing the target labels (species) from y.
    • df.head() - Displays the first five rows of the DataFrame df to provide a preview of the data, including the features and target labels.
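
As an optional sanity check (reusing the df and iris variables from the block above), it can help to confirm the class balance and map the integer labels back to species names:
# Count how many samples belong to each integer class label (0, 1, 2)
print(df['target'].value_counts())

# Map each integer label to its species name for readability
print(dict(enumerate(iris.target_names)))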

2. Train the Decision Tree Classifier

Now, let’s train a Decision Tree model using the Scikit-learn DecisionTreeClassifier. We will first split the dataset into training and testing sets, then train the classifier on the training set.
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Initialize and train the Decision Tree classifier
dt_classifier = DecisionTreeClassifier(random_state=42)
dt_classifier.fit(X_train, y_train)
    
The previous code block consists of the following code lines:
  • Import necessary libraries:
    • from sklearn.model_selection import train_test_split - Imports the train_test_split function from the sklearn.model_selection module to split the dataset into training and testing sets.
    • from sklearn.tree import DecisionTreeClassifier - Imports the DecisionTreeClassifier from the sklearn.tree module to create and train a decision tree model for classification tasks.
  • Split the data into training and testing sets:
    • X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) - Splits the feature matrix X and target labels y into training and testing sets, with 30% of the data allocated for testing. The random_state=42 ensures reproducibility of the data split.
  • Initialize and train the Decision Tree classifier:
    • dt_classifier = DecisionTreeClassifier(random_state=42) - Initializes the DecisionTreeClassifier object; random_state=42 makes the tree construction reproducible.
    • dt_classifier.fit(X_train, y_train) - Trains the decision tree classifier on the training features X_train and the corresponding labels y_train.
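
DecisionTreeClassifier also accepts several hyperparameters beyond the defaults used above. As an optional sketch (reusing X_train and y_train from this step), the snippet below fits a tree that splits on information gain instead of Gini impurity and prints the learned feature importances:
# Fit a tree using the entropy (information gain) split criterion
dt_entropy = DecisionTreeClassifier(criterion='entropy', random_state=42)
dt_entropy.fit(X_train, y_train)

# feature_importances_ reports each feature's total contribution to the splits
for name, importance in zip(iris.feature_names, dt_entropy.feature_importances_):
    print(f"{name}: {importance:.3f}")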

3. Evaluate the Model

After training the model, let’s evaluate its performance by making predictions on the test set and calculating the accuracy.
from sklearn.metrics import accuracy_score

# Make predictions on the test set
y_pred = dt_classifier.predict(X_test)

# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
    
The previous code block consists of the following code lines:
  • Import the necessary metric:
    • from sklearn.metrics import accuracy_score - Imports the accuracy_score function from the sklearn.metrics module to evaluate the performance of the model based on its accuracy.
  • Make predictions on the test set:
    • y_pred = dt_classifier.predict(X_test) - Uses the trained dt_classifier to predict the target values for the test set X_test, storing the predicted labels in y_pred.
  • Calculate and print the accuracy:
    • accuracy = accuracy_score(y_test, y_pred) - Compares the predicted labels y_pred with the true labels y_test to calculate the accuracy of the model. The accuracy is the proportion of correct predictions.
    • print(f"Accuracy: {accuracy:.2f}") - Prints the calculated accuracy, formatting the value to two decimal places using .2f.
When the code is executed, the following accuracy value is obtained:
Accuracy: 1.00
The achieved accuracy of the decision tree classifier is 1.00, indicating perfect classification on this test split; the Iris dataset is small and well separated, so a perfect score here is not unusual. The next step is to visualize the decision tree.
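
Accuracy alone can hide per-class errors, so it is often worth adding a confusion matrix and per-class metrics as well (an optional check, reusing y_test and y_pred from above):
from sklearn.metrics import confusion_matrix, classification_report

# Rows are true classes, columns are predicted classes
print(confusion_matrix(y_test, y_pred))

# Precision, recall, and F1-score for each species
print(classification_report(y_test, y_pred, target_names=iris.target_names))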

4. Visualize the Decision Tree

One of the advantages of Decision Trees is their interpretability. We can visualize the trained decision tree using Scikit-learn’s `plot_tree` function.
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# Visualize the decision tree
plt.figure(figsize=(12,8))
plot_tree(dt_classifier, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)
plt.show()
    
The previous code block consists of the following code lines:
  • Import necessary libraries for visualization:
    • from sklearn.tree import plot_tree - Imports the plot_tree function from sklearn.tree to plot the structure of the decision tree.
    • import matplotlib.pyplot as plt - Imports matplotlib.pyplot to handle the visualization and plotting of the decision tree.
  • Create a figure for the plot:
    • plt.figure(figsize=(12,8)) - Initializes a new figure with a specified size of 12 by 8 inches for a clear and readable visualization.
  • Plot the decision tree:
    • plot_tree(dt_classifier, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True) - Uses the plot_tree function to create a visual representation of the trained decision tree (dt_classifier). The arguments:
      • filled=True - Fills the nodes with colors to represent the class distribution.
      • feature_names=iris.feature_names - Specifies the feature names from the Iris dataset to label the features in the tree.
      • class_names=iris.target_names - Specifies the class names (species) for the target labels.
      • rounded=True - Rounds the corners of the nodes for aesthetic purposes.
  • Display the plot:
    • plt.show() - Displays the decision tree plot.
After executing the code, the following figure is obtained.
Figure 1 - Graphical representation of the decision tree classifier

The plot above shows the decision tree and how it splits the data at each node. Each node represents a feature and a threshold that is used to split the data. The leaf nodes show the predicted class for each partition.
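
If a graphical plot is not convenient, the same structure can also be dumped as plain-text rules with Scikit-learn's export_text function (a small optional alternative to plot_tree):
from sklearn.tree import export_text

# Print the tree as indented, if/else-style rules
rules = export_text(dt_classifier, feature_names=list(iris.feature_names))
print(rules)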

Pruning the Decision Tree

Decision Trees can easily become too complex and overfit the data if they grow too deep. One way to mitigate overfitting is by pruning the tree. Pruning involves setting a maximum depth for the tree or requiring a minimum number of samples at a node.
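
Maximum depth, shown next, is only one of several pruning controls; minimum-sample constraints and cost-complexity pruning work in a similar spirit. A brief sketch, reusing X_train and y_train with illustrative parameter values:
# Require at least 5 samples in every leaf so tiny, noisy partitions are avoided
dt_min_leaf = DecisionTreeClassifier(min_samples_leaf=5, random_state=42)
dt_min_leaf.fit(X_train, y_train)

# Cost-complexity pruning: larger ccp_alpha removes more low-value subtrees
dt_ccp = DecisionTreeClassifier(ccp_alpha=0.01, random_state=42)
dt_ccp.fit(X_train, y_train)

print(dt_min_leaf.get_depth(), dt_ccp.get_depth())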

Prune the Tree Using Maximum Depth

We can limit the depth of the tree to prevent overfitting by setting the `max_depth` parameter.
# Train the decision tree with a maximum depth of 3
dt_classifier_pruned = DecisionTreeClassifier(max_depth=3, random_state=42)
dt_classifier_pruned.fit(X_train, y_train)

# Visualize the pruned decision tree
plt.figure(figsize=(12,8))
plot_tree(dt_classifier_pruned, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)
plt.show()
    
The previous code block consists of the following code lines:
  • Train a pruned decision tree with a maximum depth of 3:
    • dt_classifier_pruned = DecisionTreeClassifier(max_depth=3, random_state=42) - Initializes a new DecisionTreeClassifier with a maximum depth of 3, which limits the tree's complexity. The random_state=42 ensures reproducibility.
    • dt_classifier_pruned.fit(X_train, y_train) - Trains the decision tree classifier on the training data (X_train) and corresponding labels (y_train).
  • Create a figure for the pruned tree plot:
    • plt.figure(figsize=(12,8)) - Initializes a new figure for plotting the decision tree, setting the figure size to 12 by 8 inches for clear visualization.
  • Plot the pruned decision tree:
    • plot_tree(dt_classifier_pruned, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True) - Uses the plot_tree function to create a visual representation of the pruned decision tree (dt_classifier_pruned). The arguments:
      • filled=True - Fills the nodes with colors to represent the class distribution.
      • feature_names=iris.feature_names - Specifies the feature names from the Iris dataset to label the features in the tree.
      • class_names=iris.target_names - Specifies the class names (species) for the target labels.
      • rounded=True - Rounds the corners of the nodes for aesthetic purposes.
  • Display the pruned tree plot:
    • plt.show() - Displays the pruned decision tree plot.
After the code is executed, the graphical representation of the decision tree classifier with its depth limited to 3 is shown in Figure 2.
Figure 2 - Decision tree classifier with depth limited to 3
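
As an optional follow-up, it is worth checking how pruning affects test accuracy (reusing the test split and accuracy_score from earlier):
# Evaluate the pruned tree on the same held-out test set
y_pred_pruned = dt_classifier_pruned.predict(X_test)
print(f"Pruned tree accuracy: {accuracy_score(y_test, y_pred_pruned):.2f}")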

Conclusion

Decision Trees are a powerful tool for classification tasks, offering advantages like interpretability and the ability to handle both numerical and categorical data. However, they can easily overfit, which can be mitigated by pruning the tree or using ensemble methods like Random Forests.

In this post, we have covered the basics of Decision Trees, including how to implement and evaluate a Decision Tree classifier using Scikit-learn, as well as how to visualize and prune the tree to improve performance. Experiment with different datasets to explore the full potential of Decision Trees!
