Thursday, February 27, 2025

Visualizing Decision Trees in Scikit-learn

Decision Trees are one of the most intuitive machine learning models, and a great advantage is that they can be visualized to understand how decisions are made at each step. In this post, we will explore different ways to visualize Decision Trees using Python’s Scikit-learn library.

Why Visualize a Decision Tree?

Understanding the structure of a Decision Tree helps with:

  • Interpreting the model's decision-making process.
  • Identifying important features used for classification or regression.
  • Detecting overfitting when the tree is too deep.

1. Training a Decision Tree in Scikit-learn

First, let's train a Decision Tree using the Iris dataset for classification.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Load the dataset
iris = load_iris()
X, y = iris.data, iris.target

# Train a Decision Tree Classifier
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X, y)
    
The previous code block consists of the following lines:
  • Load the dataset:
    • iris = load_iris() - Loads the Iris dataset from sklearn.datasets. The dataset contains features of iris flowers and their corresponding species labels.
    • X, y = iris.data, iris.target - X contains the feature data (iris flower measurements), and y contains the target data (iris species labels).
  • Train a Decision Tree Classifier:
    • clf = DecisionTreeClassifier(max_depth=3, random_state=42) - Initializes a Decision Tree Classifier with a maximum depth of 3 to control the complexity of the tree and prevent overfitting. The random_state=42 ensures that the results are reproducible.
    • clf.fit(X, y) - Fits the Decision Tree Classifier (clf) to the Iris dataset. This step trains the classifier using the features (X) and target labels (y).
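Before moving on to visualization, it is worth sanity-checking that the tree actually learned the data. The following self-contained sketch repeats the training step above and then checks the training accuracy and a single prediction (training and evaluating on the same data is fine for a quick check, but not for real model assessment):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Retrain the same classifier as above
iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X, y)

# Training accuracy (optimistic, since we evaluate on the training data)
accuracy = clf.score(X, y)
print(f"Training accuracy: {accuracy:.3f}")

# Predict the species of the first sample
predicted = iris.target_names[clf.predict(X[:1])[0]]
print(f"Predicted class for first sample: {predicted}")
```

With `max_depth=3` the training accuracy should be high but typically below 100%, which is exactly the kind of controlled complexity the depth limit is meant to enforce.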

2. Visualizing with plot_tree

Scikit-learn provides the plot_tree function to directly visualize a trained Decision Tree.

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# Plot the Decision Tree
plt.figure(figsize=(12,8))
plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)
plt.show()
    
The previous code block consists of the following lines:
  • Import necessary libraries:
    • import matplotlib.pyplot as plt - Imports the matplotlib.pyplot module, which is used for plotting graphs and figures.
    • from sklearn.tree import plot_tree - Imports the plot_tree function from sklearn.tree to visualize decision trees.
  • Plot the Decision Tree:
    • plt.figure(figsize=(12,8)) - Creates a figure with a specified size of 12x8 inches to plot the decision tree.
    • plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True) - Plots the trained decision tree clf. The filled=True argument colors the nodes according to the predicted class. The feature_names=iris.feature_names and class_names=iris.target_names add feature and class names to the plot, respectively. The rounded=True argument makes the nodes have rounded corners for a cleaner appearance.
    • plt.show() - Displays the decision tree plot on the screen.

Explanation:

  • The filled=True argument colors the nodes based on the predicted class.
  • The feature_names and class_names arguments add labels for better understanding.
  • The rounded=True makes the boxes have rounded corners for better readability.
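If you also want to keep the figure as an image file, plt.savefig can write it to disk before (or instead of) plt.show(). The sketch below is self-contained, so it retrains the classifier first; the filename decision_tree.png and the Agg backend (used here so the snippet also runs without a display) are example choices, not requirements:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; lets this run headless
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# Plot the tree and save it to a PNG file
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=iris.feature_names,
          class_names=iris.target_names, rounded=True)
plt.savefig("decision_tree.png", dpi=150, bbox_inches="tight")
plt.close()
```

Call plt.savefig before plt.show() when using an interactive backend, because some backends clear the figure once it has been shown.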

3. Exporting Tree as a Graph using export_graphviz

Another approach is using Graphviz to create a graphical representation of the tree.

from sklearn.tree import export_graphviz
import graphviz

# Export the Decision Tree
dot_data = export_graphviz(clf, out_file=None, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)

# Visualize using Graphviz
graph = graphviz.Source(dot_data)
graph.render("decision_tree")  # Saves the tree as a .pdf file
graph
    
The previous code block consists of the following lines:
  • Import necessary libraries:
    • from sklearn.tree import export_graphviz - Imports the export_graphviz function, which is used to export a decision tree in the Graphviz DOT format.
    • import graphviz - Imports the graphviz module, which is used to render and visualize Graphviz DOT format files.
  • Export the Decision Tree:
    • dot_data = export_graphviz(clf, out_file=None, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True) - Exports the trained decision tree clf to the DOT format. The filled=True argument colors the nodes according to the predicted class. The feature_names=iris.feature_names and class_names=iris.target_names add feature and class names to the tree. The rounded=True argument ensures the nodes have rounded corners.
  • Visualize using Graphviz:
    • graph = graphviz.Source(dot_data) - Creates a Graphviz source object from the DOT data, which represents the decision tree.
    • graph.render("decision_tree") - Renders the decision tree and saves it as a .pdf file with the name decision_tree.pdf.
    • graph - Displays the decision tree visually.

Explanation:

  • The tree is exported as a DOT file format using export_graphviz.
  • Graphviz is used to render the visualization.
  • To view the output, run the script in a Jupyter Notebook or save it as an image/PDF.
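If Graphviz is not installed, scikit-learn's export_text offers a dependency-free alternative: a plain-text rendering of the same tree, one indented line per node. A minimal self-contained sketch:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# Text rendering of the tree: split conditions indented by depth,
# leaves labeled with the predicted class index
tree_rules = export_text(clf, feature_names=iris.feature_names)
print(tree_rules)
```

This is handy for logging tree structure or diffing trees between runs, where an image would be awkward.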

4. Feature Importance in Decision Trees

Decision Trees can help us understand which features are most important in making predictions.

# Get feature importances
feature_importances = clf.feature_importances_

# Print feature importance
for feature, importance in zip(iris.feature_names, feature_importances):
    print(f"{feature}: {importance:.4f}")
    
The previous code block consists of the following lines:
  • Get feature importances:
    • feature_importances = clf.feature_importances_ - Retrieves the feature importances from the trained decision tree classifier clf. The feature_importances_ attribute provides the relative importance of each feature in making predictions.
  • Print feature importances:
    • for feature, importance in zip(iris.feature_names, feature_importances): - Iterates over the feature names and their corresponding importances, using zip to pair each feature with its importance value.
    • print(f"{feature}: {importance:.4f}") - Prints the name of each feature along with its importance value, formatted to four decimal places.

Explanation:

  • Higher values indicate that a feature plays a more significant role in decision-making.
  • This helps in feature selection by identifying which features contribute the most to the prediction.
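To see that ranking at a glance, the importances can be sorted from highest to lowest. The sketch below is self-contained, so it retrains the same classifier first; with this tree, the petal measurements should come out on top:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

# Pair each feature with its importance and sort, most important first
ranked = sorted(zip(iris.feature_names, clf.feature_importances_),
                key=lambda pair: pair[1], reverse=True)
for feature, importance in ranked:
    print(f"{feature}: {importance:.4f}")
```

Note that the importances are normalized to sum to 1, so each value can be read as the feature's share of the tree's total impurity reduction.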

5. Interactive Decision Tree Visualization

We can create an interactive tree visualization using dtreeviz, an external library.

Installation:

pip install dtreeviz
    

Usage:

import dtreeviz

# Wrap the trained model and training data in a dtreeviz model object
viz_model = dtreeviz.model(clf, X_train=X, y_train=y,
                           target_name="species",
                           feature_names=iris.feature_names,
                           class_names=list(iris.target_names))

# Render and display the tree
v = viz_model.view()
v.show()
    
Note: older dtreeviz releases exposed a dtreeviz() function in dtreeviz.trees; since version 2.x this has been replaced by the model-based API shown above.

The previous code block consists of the following lines:
  • Import the necessary library:
    • import dtreeviz - Imports the dtreeviz library, which is used for visualizing decision trees in a more interactive and detailed manner.
  • Generate the visualization:
    • viz_model = dtreeviz.model(...) - Wraps the trained classifier and its training data in a dtreeviz model object. It takes the following parameters:
      • clf - The trained decision tree classifier.
      • X_train=X - The feature matrix containing the input data.
      • y_train=y - The target values corresponding to the input data.
      • target_name="species" - The name of the target variable (in this case, "species").
      • feature_names=iris.feature_names - The list of feature names.
      • class_names=list(iris.target_names) - The list of class names for the target variable (the iris species).
  • Display the tree:
    • v = viz_model.view() - Renders the decision tree visualization.
    • v.show() - Opens the rendered tree in the default viewer; in a Jupyter Notebook, the object returned by view() also renders inline.

Conclusion

Visualizing Decision Trees helps in understanding model decisions and identifying overfitting. In this tutorial, we explored:

  • plot_tree: A quick way to visualize the tree.
  • export_graphviz: Exporting the tree as a DOT file for use with Graphviz.
  • Feature Importance: Understanding the significance of features in the decision process.
  • dtreeviz: Creating an interactive visualization for deeper insights.

Try these methods with your own datasets to enhance your machine learning models!
