Wednesday, December 11, 2024

Linear Regression Explained with scikit-learn

Linear regression is one of the simplest and most widely used algorithms in machine learning for predicting continuous numerical values. It assumes a linear relationship between the input features and the target variable, making it easy to understand, interpret, and implement. In this post, we will explain what linear regression is, how it works, and how to implement it using the popular Python library, scikit-learn.

What is Linear Regression?

Imagine you have a basket of chocolates and a piggy bank full of coins. You want to know:
If I have more coins, can I buy more chocolates?

We can draw a straight line (Figure 1) assuming a linear relationship between number of coins and the number of chocolates.
Figure 1 - The linear relationship between number of coins and number of chocolates.
This line is like a magic rule that tells us:
  • How many chocolates you can get if you have a certain number of coins.
  • If you have more coins, you'll probably get more chocolates!
We use math to draw the best line that matches how coins and chocolates work together. That's called linear regression!
Imagine we have a set of points on a graph. Each point shows how many coins you have and how many chocolates you can buy. Linear regression helps us find the best straight line that goes through (or near) those points. The equation for that line can be written as: \begin{equation} y = mx + b \end{equation} where (a tiny worked example follows this list):
  • \(y\) is the number of chocolates you get,
  • \(x\) is the number of coins you have,
  • \(m\) is the slope of the line (how much the number of chocolates changes for each extra coin), and
  • \(b\) is the y-intercept (how many chocolates you would get even if you had 0 coins).
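To see how this rule works, suppose the line turned out to be \(y = 2x + 1\) (illustrative numbers, not fitted from any real data). With \(x = 3\) coins you would expect \begin{equation} y = 2 \cdot 3 + 1 = 7 \end{equation} chocolates, and every extra coin buys two more.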
Linear regression is a statistical method used to model the relationship between a dependent variable (also called the target or response variable) and one or more independent variables (also known as features or predictors).
The goal of linear regression is to find the best-fitting line (or hyperplane, in the case of multiple features) that minimizes the difference between the actual and predicted values of the target.
The equation for a simple linear regression with one feature is: \begin{equation} y = \beta_0 + \beta_1 x + \epsilon, \end{equation} where:
  • \(y\) - is the target value we want to predict
  • \(\beta_0\) - is the intercept (the value of \(y\) when \(x=0\))
  • \(\beta_1\) - is the coefficient (weight) of the feature \(x\)
  • \(x\) - is the input feature (independent variable).
  • \(\epsilon\) - is the error term (residuals), which represents the difference between the predicted and actual values
In multiple linear regression, the equation is extended to account for more than one feature: \begin{equation} y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_n x_n + \epsilon, \end{equation} where:
  • \(x_1, x_2,...,x_n\) - are the input features.
  • \(\beta_1, \beta_2,...,\beta_n\) - are the corresponding coefficients.
The objective of linear regression is to estimate the coefficients (\(\beta_0, \beta_1,...,\beta_n\)) that best fit the data.
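Later in this post we fit a single-feature model step by step, but as a quick sketch of the multiple-feature case, scikit-learn handles it with exactly the same API. The two synthetic features and their true coefficients below are invented purely for illustration:
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic data: y = 1.0 + 2.0*x1 - 3.0*x2 + noise (coefficients chosen for illustration)
rng = np.random.default_rng(0)
X = rng.random((200, 2))  # 200 samples, 2 features
y = 1.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1] + rng.normal(0, 0.1, size=200)

model = LinearRegression().fit(X, y)
print(model.intercept_)  # estimate of beta_0, close to 1.0
print(model.coef_)       # estimates of [beta_1, beta_2], close to [2.0, -3.0]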

How Does Linear Regression Work?

In linear regression, the relationship between the features and the target variable is modeled by fitting a line (or hyperplane in higher dimensions) to the data. The algorithm uses a method called least squares to find the best-fitting line. Specifically, it minimizes the sum of squared residuals (SSR), the sum of the squared differences between the observed and predicted values: \begin{equation} \mathrm{SSR} = \sum (y_{\mathrm{true}} - y_{\mathrm{pred}})^2 \end{equation} This optimization process finds the coefficients \(\beta_0, \beta_1,...,\beta_n\) that minimize the SSR, ensuring the model fits the data as well as possible. Once the coefficients are estimated, the model can predict the target variable \(y\) for new, unseen data by plugging the feature values into the equation.
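To make least squares concrete, here is a minimal NumPy sketch on a tiny made-up dataset (not scikit-learn's internal routine). It solves the least-squares problem with np.linalg.lstsq, which is equivalent to the normal-equation solution \(\hat{\beta} = (X^\top X)^{-1} X^\top y\) but more numerically stable:
import numpy as np

# Tiny made-up dataset: one feature, five observations
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 4.3, 6.2, 8.1, 9.9])

# Design matrix with a leading column of ones for the intercept term
X = np.column_stack([np.ones_like(x), x])

# Least-squares fit; equivalent to solving the normal equations
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
intercept, slope = beta

# Sum of squared residuals for the fitted line
ssr = np.sum((y - X @ beta) ** 2)
print(f"intercept={intercept:.3f}, slope={slope:.3f}, SSR={ssr:.4f}")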

Linear Regression in scikit-learn

Scikit-learn provides a simple and efficient implementation of linear regression through the LinearRegression class. Below, we will walk through an example of how to use scikit-learn to perform linear regression.

Step 1: Import Necessary Libraries

First, we need to import the required libraries. We will use NumPy for handling data arrays, Matplotlib for plotting, and scikit-learn for the linear regression implementation.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

Step 2: Prepare the Data

For this example, let’s use a simple synthetic dataset where we have one feature and one target variable. We can use NumPy to generate some sample data.
# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 1) * 10 # 100 data points with one feature
y = 2.5 * X + np.random.randn(100, 1) * 2 # Target variable with some noise
In this case, the true relationship is \(y = 2.5x + \epsilon\), where \(\epsilon\) is some added random noise.

Step 3: Split the Data into Training and Testing Sets

Before training the model, we will split the data into training and testing sets using scikit-learn’s train_test_split function. This ensures that we can evaluate the model on unseen data after training.
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Step 4: Train the Linear Regression Model

Next, we initialize the LinearRegression model and train it on the training data using the fit method.
# Initialize and train the model
model = LinearRegression()
model.fit(X_train, y_train)
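Before predicting, it is worth peeking at what the model learned. Since we generated the data from \(y = 2.5x + \epsilon\), the fitted slope should land near 2.5 and the intercept near 0 (the exact values depend on the noise):
# Inspect the fitted parameters (y was 2D, so coef_ is a 2D array)
print(f"Slope (beta_1): {model.coef_[0][0]:.3f}")        # should be close to 2.5
print(f"Intercept (beta_0): {model.intercept_[0]:.3f}")  # should be close to 0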

Step 5: Make Predictions

Once the model is trained, we can use it to make predictions on the testing set.
# Make predictions on the test set
y_pred = model.predict(X_test)
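The same predict call also works for brand-new inputs; for example, a single hypothetical feature value of 4.0 (chosen only for illustration):
# Predict for a single new input; note the 2D shape (n_samples, n_features)
x_new = np.array([[4.0]])
print(model.predict(x_new))  # roughly 2.5 * 4 = 10 for this synthetic data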

Step 6: Evaluate the Model

After making predictions, we evaluate the model using common regression metrics, such as Mean Squared Error (MSE) and R-squared (R²).
# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse}")
print(f"R-squared: {r2}")
Mean Squared Error (MSE) measures the average squared difference between the actual and predicted values; a lower value indicates better performance. R-squared (R²) measures how well the model explains the variance in the target variable. It typically ranges from 0 to 1, with 1 indicating perfect prediction (it can even be negative for a model that fits worse than simply predicting the mean).
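As a sanity check, both metrics are easy to reproduce by hand with NumPy. The snippet below mirrors what the scikit-learn functions compute, using \(R^2 = 1 - SS_{res}/SS_{tot}\):
# Manual recomputation of the two metrics for comparison
mse_manual = np.mean((y_test - y_pred) ** 2)
ss_res = np.sum((y_test - y_pred) ** 2)         # residual sum of squares
ss_tot = np.sum((y_test - y_test.mean()) ** 2)  # total sum of squares
r2_manual = 1 - ss_res / ss_tot
print(f"Mean Squared Error (manual): {mse_manual}")
print(f"R-squared (manual): {r2_manual}")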

Step 7: Visualize the Results

To understand how well the model fits the data, let's create a plot that shows actual versus predicted values.
# Plot actual vs predicted values
plt.scatter(X_test, y_test, color='blue', label='Actual')
plt.plot(X_test, y_pred, color='red', label='Predicted', linewidth=2)
plt.xlabel('Feature (X)')
plt.ylabel('Target (y)')
plt.legend()
plt.show()
When we run this code, we get a plot showing the actual vs. predicted values, along with the model’s R-squared score and MSE (Figure 2).
Figure 2 - Comparison of actual target values and values predicted by the linear regression model.
If the model is performing well, the red line (predicted values) should closely follow the blue dots (actual values), and the R-squared score should be high (close to 1). From the results shown in Figure 2, it can be seen that not all points lie on the red line (the model's predictions); some points are far from the fitted line. However, the model captured the trend of the real data.
The R-squared and MSE values were 0.95457 and 2.61479, respectively. Both values indicate a good approximation by linear regression; however, there is still room for improvement with other machine learning algorithms.

Linear regression is a powerful and easy-to-understand technique for predicting continuous values. Scikit-learn makes it simple to implement linear regression in Python, providing a robust framework for training, evaluating, and making predictions. In this post, we demonstrated how to use scikit-learn's LinearRegression class to perform regression, split data into training and testing sets, train the model, and evaluate its performance. Understanding linear regression is crucial for tackling a variety of real-world regression problems, from predicting prices to forecasting demand. By exploring different regularization techniques like Ridge and Lasso regression, you can extend linear models to handle more complex and high-dimensional datasets, making them even more powerful.
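As a starting point for that exploration, here is a minimal sketch of how Ridge and Lasso slot into the same workflow; the alpha penalty strengths below are illustrative placeholders that you would normally tune (for example, with cross-validation):
from sklearn.linear_model import Ridge, Lasso

# Same fit/predict/score interface as LinearRegression, plus a penalty strength
ridge = Ridge(alpha=1.0).fit(X_train, y_train)  # L2 penalty shrinks coefficients
lasso = Lasso(alpha=0.1).fit(X_train, y_train)  # L1 penalty can zero them out
print(f"Ridge R-squared: {ridge.score(X_test, y_test):.4f}")
print(f"Lasso R-squared: {lasso.score(X_test, y_test):.4f}")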
