Linear regression is one of the simplest and most widely used algorithms in machine learning for predicting continuous numerical values. It assumes a linear relationship between the input features and the target variable, making it easy to understand, interpret, and implement. In this post, we will explain what linear regression is, how it works, and how to implement it using the popular Python library, scikit-learn.
What is Linear Regression?
Imagine you have a basket of chocolates and a piggy bank full of coins. You want to know:
- How many chocolates can you get if you have a certain number of coins?
- If you have more coins, you'll probably get more chocolates!
We can draw a straight line (Figure 1) assuming a linear relationship between the number of coins and the number of chocolates. Imagine we have a set of points on a graph. Each point shows how many coins you have and how many chocolates you can buy. Linear regression helps us find the best straight line that goes through (or near) those points. The equation for that line can be written as: \begin{equation} y = mx + b \end{equation} where:
- \(y\) is the number of chocolates you get,
- \(x\) is the number of coins you have,
- \(m\) is the slope of the line (how much the number of chocolates changes as you get more coins), and
- \(b\) is the y-intercept (how many chocolates you would get even if you had 0 coins).
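To make the line concrete, here is a tiny sketch that plugs assumed values of \(m\) and \(b\) (chosen purely for illustration) into \(y = mx + b\):
m = 2  # slope: chocolates gained per extra coin (assumed value)
b = 1  # intercept: chocolates you get with zero coins (assumed value)
coins = 5
chocolates = m * coins + b  # y = mx + b
print(chocolates)  # 11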
The goal of linear regression is to find the best-fitting line (or hyperplane, in the case of multiple features) that minimizes the difference between the actual and predicted values of the target.
The equation for a simple linear regression with one feature is: \begin{equation} y = \beta_0 + \beta_1 x + \epsilon, \end{equation} where:
- \(y\) - is the predicted target value
- \(\beta_0\) - is the intercept (the value of \(y\) when \(x=0\))
- \(\beta_1\) - is the coefficient (weight) of the feature \(x\)
- \(x\) - is the input feature (independent variable).
- \(\epsilon\) - is the error term (residuals), which represents the difference between the predicted and actual values.
With multiple features, the equation generalizes to: \begin{equation} y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_n x_n + \epsilon, \end{equation} where:
- \(x_1, x_2,...,x_n\) - are the input features.
- \(\beta_1, \beta_2,...,\beta_n\) - are the corresponding coefficients.
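To make this concrete, a prediction under the multiple-feature equation is just a dot product plus the intercept. The coefficient and feature values in this small sketch are assumed purely for illustration:
import numpy as np
beta_0 = 1.0  # intercept (assumed value)
betas = np.array([2.0, -0.5, 0.3])  # coefficients for three features (assumed values)
x = np.array([4.0, 1.0, 2.0])  # feature values for one sample (assumed values)
y_pred = beta_0 + np.dot(betas, x)  # y = beta_0 + beta_1*x_1 + ... + beta_n*x_n
print(y_pred)  # 9.1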
How Does Linear Regression Work?
In linear regression, the relationship between the features and the target variable is modeled by fitting a line (or hyperplane in higher dimensions) to the data. The algorithm uses a method called least squares to find the best-fitting line. Specifically, it minimizes the sum of squared residuals, which is the difference between the observed values and the predicted values:
\begin{equation}
\mathrm{SSR} = \sum (y_{\mathrm{true}} - y_{\mathrm{pred}})^2
\end{equation}
This optimization process helps find the best coefficients \(\beta_0, \beta_1,...,\beta_n\) that minimize the SSR, ensuring the model fits the data as well as possible. Once the coefficients are estimated, the model can predict the target variable \(y\) for new, unseen data by plugging the feature values into the equation.
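As a rough sketch of what this optimization computes, the normal equation \((X^T X)\beta = X^T y\) can be solved directly with NumPy on a small made-up dataset (the data and its true slope of 2.5 are assumed here just for illustration); scikit-learn solves the same least-squares problem for us internally.
import numpy as np
rng = np.random.default_rng(0)
X = rng.random((20, 1)) * 10  # 20 samples, one feature (assumed data)
y = 2.5 * X[:, 0] + rng.normal(0, 2, 20)  # assumed true slope of 2.5 plus noise
X_design = np.column_stack([np.ones(len(X)), X])  # prepend a column of ones for beta_0
beta = np.linalg.solve(X_design.T @ X_design, X_design.T @ y)  # solve (X^T X) beta = X^T y
print(beta)  # [beta_0, beta_1], close to [0, 2.5]
ssr = np.sum((y - X_design @ beta) ** 2)  # sum of squared residuals at the minimum
print(ssr)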
Linear Regression in scikit-learn
Scikit-learn provides a simple and efficient implementation of linear regression through the LinearRegression class. Below, we will walk through an example of how to use scikit-learn to perform linear regression.
Step 1: Import Necessary Libraries
First, we need to import the required libraries. We will use NumPy for handling data arrays and scikit-learn for the linear regression implementation.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
Step 2: Prepare the Data
For this example, let’s use a simple synthetic dataset where we have one feature and one target variable. We can use NumPy to generate some sample data.
In this case, the true relationship is \(y = 2.5x + \epsilon\), where \(\epsilon\) is some added random noise.
# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 1) * 10 # 100 data points with one feature
y = 2.5 * X + np.random.randn(100, 1) * 2 # Target variable with some noise
Step 3: Split the Data into Training and Testing Sets
Before training the model, we will split the data into training and testing sets using scikit-learn’s train_test_split function. This ensures that we can evaluate the model on unseen data after training.
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
Step 4: Train the Linear Regression Model
Next, we initialize the LinearRegression model and train it on the training data using the fit method.
# Initialize and train the model
model = LinearRegression()
model.fit(X_train, y_train)
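After fitting, it can be instructive to inspect the learned parameters; for this synthetic dataset the slope should come out close to the true value of 2.5 (the exact numbers depend on the random noise).
# Inspect the learned parameters (slope should be near 2.5, intercept near 0)
print(f"Coefficient: {model.coef_[0][0]}")
print(f"Intercept: {model.intercept_[0]}")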
Step 5: Make Predictions
Once the model is trained, we can use it to make predictions on the testing set.
# Make predictions on the test set
y_pred = model.predict(X_test)
Step 6: Evaluate the Model
After making predictions, we evaluate the model using common regression metrics, such as Mean Squared Error (MSE) and R-squared (R²).
Mean Squared Error (MSE) measures the average squared difference between the actual and predicted values; a lower value indicates better performance. R-squared (R²) measures how well the model explains the variance in the target variable; it ranges from 0 to 1, with 1 indicating perfect prediction.
# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse}")
print(f"R-squared: {r2}")
Step 7: Visualize the Results
To understand how well the model fits the data, let's create a plot that shows actual versus predicted values. When we run this code, we get a plot of the actual vs. predicted values, along with the model’s R-squared score and MSE (Figure 2).
# Plot actual vs predicted values
plt.scatter(X_test, y_test, color='blue', label='Actual')
plt.plot(X_test, y_pred, color='red', label='Predicted', linewidth=2)
plt.xlabel('Feature (X)')
plt.ylabel('Target (y)')
plt.legend()
plt.show()
The R-squared and MSE values were 0.95457 and 2.61479, respectively. Both values indicate a good fit with linear regression; however, there is still room for improvement with other machine learning algorithms.
Linear regression is a powerful and easy-to-understand technique for predicting continuous values. Scikit-learn makes it simple to implement linear regression in Python, providing a robust framework for training, evaluating, and making predictions. In this post, we demonstrated how to use scikit-learn's LinearRegression class to perform regression, split data into training and testing sets, train the model, and evaluate its performance. Understanding linear regression is crucial for tackling a variety of real-world regression problems, from predicting prices to forecasting demand. By exploring different regularization techniques like Ridge and Lasso regression, you can extend linear models to handle more complex and high-dimensional datasets, making them even more powerful.
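As a pointer for further exploration, Ridge and Lasso are drop-in replacements for LinearRegression in scikit-learn; the short sketch below reuses the training data from this post, and the alpha values are arbitrary choices for illustration.
from sklearn.linear_model import Ridge, Lasso
# Same fit/predict interface as LinearRegression; alpha controls regularization strength
ridge = Ridge(alpha=1.0).fit(X_train, y_train)
lasso = Lasso(alpha=0.1).fit(X_train, y_train)
print(f"Ridge R-squared: {r2_score(y_test, ridge.predict(X_test))}")
print(f"Lasso R-squared: {r2_score(y_test, lasso.predict(X_test))}")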