A Geometric Intuition for Linear Discriminant Analysis

Omar Shehata — St. Olaf College — 2018

Linear Discriminant Analysis, or LDA, is a useful technique in machine learning for classification and dimensionality reduction. It's often used as a preprocessing step since a lot of algorithms perform better on a smaller number of dimensions.

The idea that you can take a rich 10-dimensional space and reduce it to 3 dimensions, while keeping most of the information intact, has always seemed a bit like magic to me. I like to think of LDA as a mathematical tool that allows you to peer into spaces that cannot directly be seen. It's like piercing the veil of the unknowable.

In this article, I'd like to explore the fascinating geometric side of this popular statistical technique. I found it much easier to understand once I got to see it this way, and I hope it will be useful for you as well.

For geometry is the gate of science, and the gate is so low and small that one can only enter it as a child.
— William K. Clifford

The Premise

You're given high dimensional data and you're trying to reduce it. This need not be anything esoteric. It could be data about houses on the market with price, age, size, distance to public transport and number of rooms. That's already 5 dimensions — too many to visualize all at once.

Like any good mathematician, you can start with a simple approach and explore its consequences. What happens if you just drop one of the dimensions?

Let's see what this looks like. Here is some made-up data about houses that's only in 2 dimensions. Blue dots could be houses that were sold and red ones are still looking for a buyer.

The graph on the left is the original data, and the line on the right is what it looks like after we reduce it by 1 dimension by simply ignoring the Y axis.

We've successfully reduced our data, but remember that our goal is to reduce it without losing information. In this case, you can easily tell that there is a pattern to the data when looking at the 2D view. This pattern is completely lost in the 1D view.
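To make the loss concrete, here is a minimal sketch using made-up coordinates (these are not the points from the diagrams). I'm assuming two clusters that are each stretched along a diagonal, so that they are cleanly separated in 2D but overlap once the Y value is ignored.

```python
# Made-up 2D points (hypothetical, not the article's dataset).
# Each class is stretched along the direction (1, -1).
blue = [(1 + t, 1 - t) for t in (-2, -1, 0, 1, 2)]   # centered at (1, 1)
red = [(-1 + t, -1 - t) for t in (-2, -1, 0, 1, 2)]  # centered at (-1, -1)

# "Dropping the Y axis" keeps only the x coordinate of each point.
blue_1d = [x for x, _ in blue]
red_1d = [x for x, _ in red]

# In 2D the classes sit on opposite sides of the line x + y = 0 ...
separable_in_2d = all(x + y > 0 for x, y in blue) and all(x + y < 0 for x, y in red)

# ... but on the x axis their ranges overlap.
overlap_in_1d = min(blue_1d) <= max(red_1d)

print(blue_1d)  # [-1, 0, 1, 2, 3]
print(red_1d)   # [-3, -2, -1, 0, 1]
print(separable_in_2d, overlap_in_1d)  # True True
```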

Another way to think about this is to consider prediction. Given a new unclassified dot in the 2D view, you can easily guess whether it should be blue or red based on whether it's closer to the upper right or bottom left cluster. This would be much harder if you were given just the 1D view. In other words, our prediction algorithms would perform worse on the reduced data.

We've sacrificed accuracy for simplicity, but maybe there's a better way to reduce it. Perhaps we could drop the X axis instead?

You can try that out by clicking and dragging in the 2D graph below to rotate the projection line.

You can click here to see that dropping the X axis isn't much better, but these aren't the only two choices. Geometrically, we have an infinite number of lines we can project onto!
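Projecting onto a line through the origin at some angle comes down to taking the dot product of each point with the unit vector along that line. Here is a minimal sketch of that idea. The function name `project` and the sample points are my own choices for illustration.

```python
import math

def project(points, angle_degrees):
    """Return the 1D coordinate of each 2D point along a line through the origin."""
    theta = math.radians(angle_degrees)
    ux, uy = math.cos(theta), math.sin(theta)  # unit vector along the line
    return [x * ux + y * uy for x, y in points]

points = [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]

on_x_axis = project(points, 0)     # same as dropping the Y axis
on_y_axis = project(points, 90)    # same as dropping the X axis
on_diagonal = project(points, 45)  # one of the infinitely many other lines

print(on_x_axis)  # [1.0, 0.0, 1.0]
```

Every angle gives a different 1D view of the same data, which is why there are infinitely many candidates to choose from.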

Out of all these possible lines, the line y = x provides the best possible separation. This best line is what LDA allows us to find.

The reduced data is only 1 dimensional, but it captures the important insight of the higher dimensional view. Think again about prediction. If we're looking at the 1D view, all we have to do to predict whether a new point should be blue or red is just see whether it's on the right or left side.
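As a rough sketch of what "see whether it's on the right or left side" could look like in code, here is one simple rule: put a threshold halfway between the two class means. The midpoint rule and the numbers are my own assumptions for illustration, not something the diagrams above compute.

```python
def mean(values):
    return sum(values) / len(values)

# Hypothetical 1D coordinates after projection (made up for illustration).
red_1d = [-2.1, -1.8, -1.5, -1.2]
blue_1d = [1.1, 1.6, 1.9, 2.2]

# Place the decision boundary halfway between the two class means.
threshold = (mean(red_1d) + mean(blue_1d)) / 2

def predict(value):
    """Classify a new 1D point by which side of the threshold it falls on."""
    return "blue" if value > threshold else "red"

print(predict(0.9))   # blue
print(predict(-0.4))  # red
```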

This means that not only would our prediction algorithms perform just as well on the reduced data, they would also run faster (because there's less data to process)! Strangely enough, by removing information, we've made it easier to gain insight.

Higher Dimensions

To build on this intuition, let's go one dimension up and look at an example that we can still visualize in its unreduced form.

Here we've got 3D data, this time with 3 classes, and we want to reduce it to 2D. The diagram below shows one such projection.

If you're using a keyboard, you can rotate the projection plane with W/S, A/D and Q/E.

Again, there's an obvious pattern to the data in its original form that's lost in the reduced view. Finding the best projection plane by hand that retains this pattern is a harder problem now because there's a much greater number of possible projection planes to explore in 3D compared to projection lines in 2D.

For this dataset, this best plane is x + y + z = 0. You can see how this gives us a clean separation in 2D compared to an arbitrary plane like this. Again, this best projection is what LDA allows us to find.
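Projecting 3D points onto a plane through the origin works the same way as projecting onto a line, except we need two direction vectors instead of one. Below is a minimal sketch for the plane x + y + z = 0. The basis vectors `U` and `V` are one valid choice among many (any rotation of them within the plane works equally well), and the helper names are mine.

```python
import math

# Orthonormal basis for the plane x + y + z = 0 (one of many valid choices).
U = (1 / math.sqrt(2), -1 / math.sqrt(2), 0.0)
V = (1 / math.sqrt(6), 1 / math.sqrt(6), -2 / math.sqrt(6))

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def project_to_plane(point):
    """Return the 2D coordinates of a 3D point within the plane spanned by U and V."""
    return (dot(point, U), dot(point, V))

# The plane's normal direction (1, 1, 1) collapses to (approximately) the origin,
# so any movement along it is invisible in the 2D view.
print(project_to_plane((1.0, 1.0, 1.0)))
print(project_to_plane((1.0, 0.0, 0.0)))
```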

Now here's the real test. What if we had 4D data that we couldn't even visualize? Below is what it looks like projected down to 2D (you can think of it as first projecting into 3D and then into 2D).

Rotation angles, one per pair of axes: αXY, αXZ, αYZ, αXW, αYW, αZW.

You can rotate the projection plane in 4D space with the key pairs W/S, A/D, Q/E, J/L and I/K or if you're on a mobile device by clicking on the numbers and typing in an angle.
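There are six angles because 4D space has six pairs of axes, and each pair defines a plane you can rotate in. Here is a minimal sketch of how such a view could be computed: rotate the point in each plane, then keep the first two coordinates. The order in which the rotations are applied is my own assumption and may differ from what the diagram does.

```python
import math

def rotate(point, i, j, angle_degrees):
    """Rotate a point in the plane spanned by coordinate axes i and j."""
    theta = math.radians(angle_degrees)
    c, s = math.cos(theta), math.sin(theta)
    rotated = list(point)
    rotated[i] = c * point[i] - s * point[j]
    rotated[j] = s * point[i] + c * point[j]
    return tuple(rotated)

# Axis indices: X=0, Y=1, Z=2, W=3. One angle per pair of axes.
PLANES = [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]

def project_4d(point, angles_degrees):
    """Apply the six rotations in a fixed (assumed) order, then keep the first two coordinates."""
    for (i, j), angle in zip(PLANES, angles_degrees):
        point = rotate(point, i, j, angle)
    return point[:2]

# With all angles at zero the view is simply the X and Y coordinates.
print(project_4d((1.0, 2.0, 3.0, 4.0), [0, 0, 0, 0, 0, 0]))  # (1.0, 2.0)
# Rotating 90 degrees in the XW plane brings the W coordinate into view.
print(project_4d((1.0, 2.0, 3.0, 4.0), [0, 0, 0, 90, 0, 0]))
```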

At first glance, this looks like an entangled mess. But we know that we might be able to find a projection that reveals the higher dimensional structure to us. I say might because we don't actually know what the original data looks like. We can keep rotating all the angles, and we might find something that looks like it has a good separation, but how do we know when we've found the best one?

So far I've been performing LDA behind the scenes to produce the answers. We will now peek behind the curtain and devise the algorithm that will allow us to make such bold claims as “this is the best possible separation for this data.”

Devising an Algorithm

This is definitely the most exciting part of mathematics in my opinion. It's the creative process of starting with an open-ended question like “How do we look at high dimensional data?” and ending up with a tool that allows us to do just that for any dataset. Rather than spoil the punchline, I'll try to guide you a little further.

The first step is to recognize this as an optimization problem. It's hard to think in the abstract, but we can rely on the intuition we've built up so far. Think back to the simple 2D case. We had a closed set of solutions (we knew the angle of the projection line was between 0 and 360 degrees), and we were trying to pick the best one.

But what exactly does “best” mean? Intuitively it means the best separation between classes in the reduced data. So far we've been judging this visually. To solve this as an optimization problem, we need to come up with a metric. This metric should be a number for any given projection that's high when it has good separation, and low otherwise. If we have that, it becomes a standard optimization problem of finding the maximum that we can easily solve with existing techniques.

So what sort of formula can we come up with that would give us a high value for the projection below:

Compared to a low value for this projection:

The simplest example of such a metric would be the difference between the class means in the reduced data. This works fairly well as a measure of how good the separation is in the above cases. But consider the case below, which I've constructed such that the mean of each class is the same as it is in the above dataset.

Since the means of the classes in these two datasets are equal, our metric gives them the same number, which implies that this last separation is as good as the best separation above, which is clearly not true. A badly designed metric like this would lead our optimization algorithm astray.
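Here is a small sketch of that failure, using made-up 1D values rather than the data in the diagrams. Both pairs of classes have the same means, so the naive metric can't tell them apart even though one pair overlaps heavily.

```python
def mean(values):
    return sum(values) / len(values)

def mean_difference(class_a, class_b):
    """The naive metric: distance between the class means in the reduced data."""
    return abs(mean(class_a) - mean(class_b))

# Hypothetical 1D projections (made up). Both datasets have class means of -2 and 2.
tight_a, tight_b = [-2.5, -2.0, -1.5], [1.5, 2.0, 2.5]
wide_a, wide_b = [-8.0, -2.0, 4.0], [-4.0, 2.0, 8.0]

print(mean_difference(tight_a, tight_b))  # 4.0
print(mean_difference(wide_a, wide_b))    # 4.0
# The wide classes overlap heavily, yet the score is identical.
print(max(wide_a) > min(wide_b))          # True
```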

Can you think of a better way to numerically differentiate a good separation from a bad one? Pretend it's the 1930s and no known solution for this exists. Give it some thought. What would you come up with?

If you came up with a solution that took into account both the means of the classes as well as the variance (spread), then give yourself a pat on the back. This is exactly what Ronald Fisher came up with in 1936.

The formula he came up with is known as Fisher's Linear Discriminant and is defined as:

$$\frac{(\mu_1 - \mu_2)^2}{S_1 + S_2}$$

Here $\mu_i$ is the mean of class $i$, and $S_i$ is the “scatter” of class $i$, defined as $S_i = \sum_{x \in \text{Class}_i} (x - \mu_i)^2$. It's a measure of how spread out the data is (the sum of the squared distances of each point from the class mean).
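Translated directly into code, the formula might look like the sketch below. The function names and the sample values are mine, and the inputs are assumed to be 1D values that have already been projected.

```python
def mean(values):
    return sum(values) / len(values)

def scatter(values):
    """Sum of squared distances from the class mean."""
    mu = mean(values)
    return sum((x - mu) ** 2 for x in values)

def fisher_score(class_1, class_2):
    """Fisher's criterion for two classes of 1D (already projected) values."""
    numerator = (mean(class_1) - mean(class_2)) ** 2
    return numerator / (scatter(class_1) + scatter(class_2))

# Hypothetical 1D projections (made up). Both pairs have class means of -2 and 2.
tight = fisher_score([-2.5, -2.0, -1.5], [1.5, 2.0, 2.5])
wide = fisher_score([-8.0, -2.0, 4.0], [-4.0, 2.0, 8.0])

print(tight)  # 16.0
print(wide)   # much smaller, because the scatter is much larger
```

For the tight pair, the numerator is $(-2 - 2)^2 = 16$ and each scatter is $0.25 + 0 + 0.25 = 0.5$, so the score is $16 / 1 = 16$. For the wide pair the numerator is the same, but each scatter is $36 + 0 + 36 = 72$, giving $16 / 144 \approx 0.11$.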

We want to maximize Fisher's Linear Discriminant, which means we want to favor projections that have greater distance between the means (a big numerator) and aren't very spread out (a small denominator).
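To tie this back to the optimization problem, here is a brute-force sketch that scores every whole-degree angle and keeps the best one. This sweep is only for intuition. LDA is normally solved analytically rather than by trying angles, and the data below is made up for illustration.

```python
import math

def mean(values):
    return sum(values) / len(values)

def scatter(values):
    mu = mean(values)
    return sum((x - mu) ** 2 for x in values)

def project(points, angle_degrees):
    theta = math.radians(angle_degrees)
    return [x * math.cos(theta) + y * math.sin(theta) for x, y in points]

def fisher_score(points_1, points_2, angle_degrees):
    p1, p2 = project(points_1, angle_degrees), project(points_2, angle_degrees)
    return (mean(p1) - mean(p2)) ** 2 / (scatter(p1) + scatter(p2))

# Made-up 2D classes (hypothetical), each wider along X than along Y.
blue = [(0, 2), (4, 2), (2, 1), (2, 3), (2, 2)]
red = [(-4, -2), (0, -2), (-2, -3), (-2, -1), (-2, -2)]

# A line's direction repeats every 180 degrees, so 0..179 covers every line.
best_angle = max(range(180), key=lambda a: fisher_score(blue, red, a))
print(best_angle)  # 76
```

Notice that the winner is not 45 degrees, the direction that joins the two class means. Because these classes are more spread out along X, the score prefers a steeper line that gives up a little distance between the means in exchange for much less scatter.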

To see how it works, as always, let's look at a concrete example. Use your mouse or finger in the diagram below to see the computed score for any projection.

Can you see that projections that create better separation get higher scores? More importantly, can you see why?

One thing I discovered by playing around with this diagram is that there's only one way for a projection to have a score of 0. Do you see what it is?

Notice how nothing in this formula is specific to the 2D case. It's always the squared difference of the means divided by the sum of the scatters, whether we're working with 2, 4 or 10 dimensions.
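As a sketch of that claim, the function below scores a projection onto a single direction for points of any dimension. The function name and the 4D points are made up for illustration. It only covers reducing to one dimension; reducing to two or more uses a matrix generalization of the same idea, which I'm not showing here.

```python
import math

def fisher_score_along(points_1, points_2, direction):
    """Fisher's criterion after projecting n-dimensional points onto one direction."""
    length = math.sqrt(sum(d * d for d in direction))
    unit = [d / length for d in direction]

    def project(points):
        return [sum(p_i * u_i for p_i, u_i in zip(p, unit)) for p in points]

    def mean(values):
        return sum(values) / len(values)

    def scatter(values):
        mu = mean(values)
        return sum((x - mu) ** 2 for x in values)

    p1, p2 = project(points_1), project(points_2)
    return (mean(p1) - mean(p2)) ** 2 / (scatter(p1) + scatter(p2))

# Made-up 4D points (hypothetical). The classes differ mainly in the fourth coordinate.
class_1 = [(0, 1, 0, 4), (1, 0, 1, 5), (0, 0, 1, 6)]
class_2 = [(1, 1, 0, -4), (0, 1, 1, -5), (1, 0, 0, -6)]

print(fisher_score_along(class_1, class_2, (0, 0, 0, 1)))  # 25.0
print(fisher_score_along(class_1, class_2, (1, 0, 0, 0)))  # far lower
```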

This particular metric is what Fisher chose, but it's certainly not the only possible one. A different metric could optimize for different results (perhaps you care a little more about minimizing scatter, so you raise the denominator to a higher power to give it more weight; simply multiplying it by a constant would rescale every score equally and leave the best projection unchanged).

I think by understanding that the design of the formulas we use is often a choice with consequences as opposed to just a given absolute, we unlock a power and joy in mathematics.

And that's how LDA works!

Initially created for the Explorable Explanations jam. The source code for this page and all interactive diagrams is available on GitHub and is public domain. The teacher's guide comes with some tips for using it in the classroom, like how to drag and drop your own data into the article.

Special thanks to Professor Matt Richey for his inspiring lectures on LDA in the Algorithms for Decision Making class.

Resources & References
  1. The associated Jupyter notebook shows you how to perform LDA yourself in Python using the data from this article.
  2. For an example with real-world data, Julia Silge writes about applying this to Stack Overflow data using PCA (a related dimensionality reduction technique that, unlike LDA, does not use class labels).
  3. For the mathematical derivation, extending it to more than 2 classes and actually solving the optimization problem, I found these slides and this article very helpful while writing this article.