A Visual Introduction to Machine Learning
In machine learning, computers apply statistical learning techniques to automatically identify patterns in data. These techniques can be used to make highly accurate predictions.
Using a data set about homes, we will create a machine learning model to distinguish homes in New York from homes in San Francisco.
First, some intuition
Let’s say you had to determine whether a home is in San Francisco or in New York. In machine learning terms, categorizing data points is a classification task.
Since San Francisco is relatively hilly, the elevation of a home may be a good way to distinguish the two cities.
Based on the home-elevation data to the right, you could argue that a home above 240 ft should be classified as one in San Francisco.
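The 240 ft intuition can be written as a one-line rule. This is only a sketch: the function name is hypothetical, the sample elevations are made up, and treating every home at or below the threshold as a New York home is an assumption the text has not made yet.

```python
# Threshold taken from the intuition above; everything else is illustrative.
ELEVATION_THRESHOLD_FT = 240


def classify_by_elevation(elevation_ft):
    """Guess a city from elevation alone (hypothetical helper)."""
    if elevation_ft > ELEVATION_THRESHOLD_FT:
        return "San Francisco"
    # Assumption: anything at or below the threshold is guessed to be New York.
    return "New York"


# Made-up elevations, not taken from the article's data set.
for elevation in (10, 150, 300):
    print(elevation, "ft ->", classify_by_elevation(elevation))
```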
Adding nuance
Adding another dimension allows for more nuance. For example, New York apartments can be extremely expensive per square foot.
So visualizing elevation and price per square foot in a scatterplot helps us distinguish lower-elevation homes.
The data suggests that, among homes at or below 240 ft, those that cost more than $1776 per square foot are in New York City.
Dimensions in a data set are called features, predictors, or variables. 1
Drawing boundaries
Then you could visualize your elevation (>242 ft) and price per square foot (>$1776) observations as the boundaries of regions in a scatterplot. Homes plotted in the green and blue regions would be in San Francisco and New York, respectively.
Identifying boundaries in data using math is the essence of statistical learning.
Of course, you’ll need additional information to distinguish homes with lower elevations and lower per-square-foot prices.
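The two hand-drawn boundaries can be sketched as a pair of if-then checks. The function name and the "undetermined" label are hypothetical. The thresholds are the 240 ft and $1776 per square foot figures from the earlier observations, and the homes passed in are made up.

```python
def classify_home(elevation_ft, price_per_sqft):
    """Apply the two intuitive boundaries in order (illustrative only)."""
    if elevation_ft > 240:
        return "San Francisco"
    if price_per_sqft > 1776:
        return "New York"
    # Lower elevation and lower price per square foot:
    # these two features alone are not enough to decide.
    return "undetermined"


# Made-up homes: (elevation in ft, price per square foot in dollars).
for home in [(300, 900), (50, 2000), (50, 800)]:
    print(home, "->", classify_home(*home))
```

The third home falls into the region where more information is needed, which is the gap the remaining dimensions are meant to fill.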
The dataset we are using to create the model has 7 different dimensions. Creating a model is also known as training a model.
We could visualize the variables in a scatterplot matrix to show the relationships between each pair of dimensions.
This would show patterns in the data, but the boundaries for delineating them are not obvious.
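To get a feel for how much a scatterplot matrix has to show, here is a small sketch that enumerates every pair of dimensions. Only elevation, price per square foot, and price are named in this story; the remaining four names are placeholders, not the real column names.

```python
from itertools import combinations

# Three dimensions named in the text plus four placeholder names (assumed).
dimensions = [
    "elevation", "price_per_sqft", "price",
    "dim_4", "dim_5", "dim_6", "dim_7",
]

# Each unordered pair corresponds to one panel of the scatterplot matrix.
pairs = list(combinations(dimensions, 2))
print(len(pairs), "pairs of dimensions")
for x_axis, y_axis in pairs[:3]:
    print(x_axis, "vs", y_axis)
```

With 7 dimensions there are 21 distinct pairs, which is one reason the boundaries are hard to spot by eye.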
And now, machine learning
Finding patterns in data is where machine learning comes in. Machine learning methods use statistical learning to identify boundaries.
One example of a machine learning method is a decision tree. Decision trees look at one variable at a time and are a reasonably accessible (though rudimentary) machine learning method.
Finding better boundaries
Let's revisit the 240-ft elevation boundary proposed previously to see how we can improve upon our intuition.
Clearly, this requires a different perspective.
By transforming our visualization into a histogram, we could better see how frequently homes appear at each elevation.
While the highest home in New York is ~240 ft, the majority of them seem to have far lower elevations.
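A histogram is just a count of homes per elevation bin. The sketch below uses made-up New York elevations and a hypothetical bin width of 40 ft; it is meant to show the shape of the computation, not the article's actual data.

```python
from collections import Counter

BIN_WIDTH_FT = 40  # hypothetical bin width


def elevation_histogram(elevations, bin_width=BIN_WIDTH_FT):
    """Count homes per elevation bin, keyed by the bin's lower edge."""
    return Counter((e // bin_width) * bin_width for e in elevations)


# Made-up New York elevations: mostly low, with one outlier near 240 ft.
new_york = [5, 12, 20, 33, 38, 45, 70, 110, 236]
hist = elevation_histogram(new_york)
for start in sorted(hist):
    print(f"{start:>3}-{start + BIN_WIDTH_FT - 1:>3} ft | {'#' * hist[start]}")
```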
Your first fork
A decision tree uses if-then statements to define patterns in data.
For example, if a home's elevation is above some number, then the home is probably in San Francisco.
In machine learning, these statements are called forks, and they split the data into two branches based on some value.
That value between the branches is called a split point. Homes to the left of that point get categorized in one way, while those to the right are categorized in another. A split point is the decision tree's version of a boundary.
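A fork can be sketched as a function that sends each home to one of two branches depending on which side of the split point it falls. The function name, the dictionary layout, and the sample homes are all illustrative assumptions.

```python
def fork(homes, feature, split_point):
    """Split homes into (left, right) branches around split_point."""
    left = [h for h in homes if h[feature] <= split_point]
    right = [h for h in homes if h[feature] > split_point]
    return left, right


# Made-up homes, not taken from the article's data set.
homes = [
    {"elevation": 10, "city": "New York"},
    {"elevation": 90, "city": "San Francisco"},
    {"elevation": 236, "city": "New York"},
    {"elevation": 300, "city": "San Francisco"},
]

left, right = fork(homes, "elevation", 240)
print(len(left), "homes at or below 240 ft,", len(right), "above")
```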
Tradeoffs
Picking a split point has tradeoffs. Our initial split (~240 ft) incorrectly classifies some San Francisco homes as New York ones.
Among homes below 240 ft, some are not in New York; those are the San Francisco homes that are misclassified. These are called false negatives.
However, a split point meant to capture every San Francisco home will include many New York homes as well. These are called false positives.
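The tradeoff can be made concrete by counting both kinds of error for two candidate split points. As in the text, San Francisco is treated as the positive class. The helper name and the ten homes are made up for illustration.

```python
def count_errors(homes, split_point):
    """Return (false_negatives, false_positives) for an elevation split.

    Homes above split_point are predicted to be in San Francisco.
    """
    false_negatives = sum(
        1 for elev, city in homes if city == "SF" and elev <= split_point
    )
    false_positives = sum(
        1 for elev, city in homes if city == "NY" and elev > split_point
    )
    return false_negatives, false_positives


# Made-up (elevation in ft, city) pairs.
homes = [
    (5, "NY"), (20, "NY"), (60, "NY"), (130, "NY"), (236, "NY"),
    (15, "SF"), (90, "SF"), (180, "SF"), (250, "SF"), (400, "SF"),
]

for split_point in (240, 10):
    fn, fp = count_errors(homes, split_point)
    print(f"split at {split_point} ft: {fn} false negatives, {fp} false positives")
```

On this toy data, the high split misses three San Francisco homes, while the low split catches all of them at the cost of mislabeling four New York homes.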
The best split
At the best split, the results of each branch should be as homogeneous (or pure) as possible. There are several mathematical methods you can choose between to calculate the best split.2
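One way to score candidate splits is Gini impurity, one of the measures named in the footnotes. The sketch below tries every midpoint between neighboring values and keeps the one with the lowest weighted impurity. It is an illustration under assumed data, not necessarily the method behind the article's numbers.

```python
def gini(labels):
    """Gini impurity for two classes: 0.0 means the branch is pure."""
    if not labels:
        return 0.0
    p_sf = sum(1 for label in labels if label == "SF") / len(labels)
    return 1.0 - p_sf ** 2 - (1.0 - p_sf) ** 2


def best_split(points):
    """Return (split_point, weighted_gini) for (value, label) pairs."""
    values = sorted({value for value, _ in points})
    best = None
    for low, high in zip(values, values[1:]):
        candidate = (low + high) / 2
        left = [label for value, label in points if value <= candidate]
        right = [label for value, label in points if value > candidate]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(points)
        if best is None or score < best[1]:
            best = (candidate, score)
    return best


# Made-up (elevation in ft, city) pairs with some overlap between the cities.
points = [
    (5, "NY"), (20, "NY"), (60, "NY"), (90, "SF"),
    (130, "NY"), (250, "SF"), (400, "SF"),
]
split_point, score = best_split(points)
print(f"best split: {split_point} ft, weighted Gini {score:.3f}")
```

The winning score is above zero because one New York home sits among the San Francisco ones, so a single split cannot make both branches pure.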
As discussed here, even the best split on a single feature does not fully separate the San Francisco homes from the New York ones.
Recursion
To add another split point, the algorithm repeats the process above on the subsets of data. This repetition is called recursion, and it is a concept that appears frequently in training models.3
The histograms to the left show the distribution of each subset, repeated for each variable.
The best split will vary based on which branch of the tree you are looking at.4
For lower-elevation homes, price per square foot, at $1061 per sqft, is the best variable for the next if-then statement. For higher-elevation homes, it is price, at $514500.
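The repetition can be sketched as a function that calls itself on each branch. This version scores forks with Gini impurity and stops when a branch is pure. The function names, the nested-dictionary tree layout, and the nine homes are assumptions made for illustration.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 0.0 when every label in the branch is the same."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_fork(rows, features):
    """Greedy search for (feature, split_point, score) with the lowest impurity."""
    best = None
    for feature in features:
        values = sorted({row[feature] for row in rows})
        for low, high in zip(values, values[1:]):
            point = (low + high) / 2
            left = [r["city"] for r in rows if r[feature] <= point]
            right = [r["city"] for r in rows if r[feature] > point]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(rows)
            if best is None or score < best[2]:
                best = (feature, point, score)
    return best


def grow(rows, features):
    """Recursively split rows until each branch holds a single city."""
    cities = Counter(row["city"] for row in rows)
    fork = best_fork(rows, features) if len(cities) > 1 else None
    if fork is None:
        return {"leaf": cities.most_common(1)[0][0]}
    feature, point, _ = fork
    left = [r for r in rows if r[feature] <= point]
    right = [r for r in rows if r[feature] > point]
    # The same procedure is applied to each subset: this is the recursion.
    return {"feature": feature, "point": point,
            "left": grow(left, features), "right": grow(right, features)}


# Made-up homes: (elevation in ft, price per square foot, city).
raw = [
    (5, 2000, "New York"), (10, 1500, "New York"), (15, 1800, "New York"),
    (20, 700, "San Francisco"), (30, 1900, "New York"),
    (40, 900, "San Francisco"), (100, 1700, "San Francisco"),
    (150, 1950, "San Francisco"), (250, 1600, "San Francisco"),
]
rows = [{"elevation": e, "price_per_sqft": p, "city": c} for e, p, c in raw]
tree = grow(rows, ("elevation", "price_per_sqft"))
print(tree)
```

On this toy data the first fork is on elevation, and the lower-elevation branch is then split on price per square foot, which mirrors how the best variable changes from branch to branch.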
Growing a tree
Additional forks will add new information that can increase a tree's prediction accuracy.
Splitting one layer deeper, the tree's accuracy improves to 84%.
Adding several more layers, we get to 96%.
You could even continue to add branches until the tree's predictions are 100% accurate, so that at the end of every branch, the homes are purely in San Francisco or purely in New York.
These ultimate branches of the tree are called leaf nodes. Our decision tree models will classify the homes in each leaf node according to which class of homes is in the majority.
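The majority rule at a leaf node is a simple vote among the training homes that ended up there. A minimal sketch, with a hypothetical helper name and a made-up leaf:

```python
from collections import Counter


def leaf_prediction(cities_in_leaf):
    """Return the most common city among the training homes in a leaf."""
    # On a tie, Counter keeps the city that was seen first.
    return Counter(cities_in_leaf).most_common(1)[0][0]


# A made-up leaf that is not perfectly pure.
leaf = ["San Francisco", "San Francisco", "New York"]
print(leaf_prediction(leaf))
```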
Making predictions
The newly-trained decision tree model determines whether a home is in San Francisco or New York by running each data point through the branches.
You could imagine the data that was used to train the tree flowing through the tree.
This data is called training data because it was used to train the model.
Because we grew the tree until it was 100% accurate, this tree maps each training data point perfectly to which city it is in.
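Running a data point through the branches amounts to following if-then statements until a leaf is reached. The small tree below is hand-written for illustration. The $1061 and $514500 split points come from the text above, while the root split point of 30 ft, the city assigned to each leaf, and the sample homes are assumptions.

```python
# Hand-written toy tree; root split and leaf labels are assumed, not from the data.
tree = {
    "feature": "elevation", "point": 30,
    "left": {
        "feature": "price_per_sqft", "point": 1061,
        "left": {"leaf": "San Francisco"},
        "right": {"leaf": "New York"},
    },
    "right": {
        "feature": "price", "point": 514500,
        "left": {"leaf": "New York"},
        "right": {"leaf": "San Francisco"},
    },
}


def predict(node, home):
    """Follow forks from the root until a leaf node is reached."""
    while "leaf" not in node:
        branch = "left" if home[node["feature"]] <= node["point"] else "right"
        node = node[branch]
    return node["leaf"]


# Made-up homes.
low_home = {"elevation": 10, "price_per_sqft": 1500, "price": 900000}
high_home = {"elevation": 200, "price_per_sqft": 900, "price": 1200000}
print(predict(tree, low_home), "/", predict(tree, high_home))
```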
Reality check
Of course, what matters more is how the tree performs on previously-unseen data.
To test the tree's performance on new data, we need to apply it to data points that it has never seen before. This previously unused data is called test data.
Ideally, the tree should perform similarly on both known and unknown data.
So this one is less than ideal.5 By running this decision tree on the new data, we find it to have 89.7% accuracy. 12 homes from San Francisco get classified as being in New York, and 13 homes from New York get classified into the San Francisco category.
These errors are due to overfitting. Our model has learned to treat every detail in the training data as important, even details that turned out to be irrelevant.
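The gap between training and test accuracy can be illustrated with a deliberately extreme stand-in for an overfit model: a lookup table that memorizes every training home. This is not a decision tree, and the data, function names, and default guess are all made up; it only shows how the comparison is carried out.

```python
def accuracy(model, rows):
    """Fraction of rows for which the model's guess matches the true city."""
    correct = sum(1 for features, city in rows if model(features) == city)
    return correct / len(rows)


def memorize(training_rows, default="New York"):
    """Build a model that recalls training rows exactly and guesses otherwise."""
    table = {features: city for features, city in training_rows}
    return lambda features: table.get(features, default)


# Made-up rows: ((elevation in ft, price per square foot), city).
train = [
    ((10, 1800), "New York"), ((25, 900), "San Francisco"),
    ((150, 700), "San Francisco"), ((40, 1200), "New York"),
]
test = [
    ((12, 1700), "New York"), ((300, 800), "San Francisco"),
    ((30, 950), "San Francisco"), ((8, 2000), "New York"),
]

model = memorize(train)
print("training accuracy:", accuracy(model, train))
print("test accuracy:", accuracy(model, test))
```

The model is perfect on the homes it has seen and no better than a fixed guess on the ones it has not, which is the pattern to look for when checking for overfitting.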
Overfitting is part of a fundamental concept in machine learning that we’ll explain in our next post.6
Recap
- Machine learning identifies patterns using statistical learning and computers by unearthing boundaries in data sets. You can use it to make predictions.
- One method for making predictions is called a decision tree, which uses a series of if-then statements to identify boundaries and define patterns in the data.
- Overfitting happens when some boundaries are based on distinctions that don't make a difference. You can see if a model overfits by having test data flow through the model.
Original Source of Story
This story is an example copied for the purposes of academic study. It was originally created by r2d3.
Footnotes
1. Machine learning concepts have arisen across disciplines (computer science, statistics, engineering, psychology, etc.), hence the different nomenclature.
2. To learn more about calculating the optimal split, search for 'gini index' or 'cross entropy'.
3. One reason computers are so good at applying statistical learning techniques is that they're able to do repetitive tasks very quickly and without getting bored.
4. The algorithm described here is greedy, because it takes a top-down approach to splitting the data. In other words, it is looking for the variable that makes each subset the most homogeneous at that moment.
5. Hover over the dots to see the path it took in the tree.
6. Spoiler alert: It's the bias/variance tradeoff!