Data Visualization

The next step in understanding your data is data visualisation.

First, let's load our data:

import pandas as pd
data = pd.read_csv('../data/pima-indians-diabetes.csv')
data.head()
   Pregnancies  Glucose  BloodPressure  SkinThickness  Insulin   BMI  DiabetesPedigreeFunction  Age  Outcome
0            6      148             72             35        0  33.6                     0.627   50        1
1            1       85             66             29        0  26.6                     0.351   31        0
2            8      183             64              0        0  23.3                     0.672   32        1
3            1       89             66             23       94  28.1                     0.167   21        0
4            0      137             40             35      168  43.1                     2.288   33        1

Using Seaborn

For data visualisation we will use seaborn, a powerful library that makes it easy to create attractive plots. Seaborn is built on top of matplotlib and adds higher-level plot types and good-looking default styles, so most plots take a single line of code.

Let's import and set up seaborn:

# import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')
%matplotlib inline

Univariate distribution plots

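The first thing we can do is look at how each variable behaves on its own.

The distplot() function plots the distribution of a single variable. (In seaborn 0.11 and later, distplot() is deprecated in favour of histplot() and displot(), which cover the same ground.)

It displays a histogram together with a "kernel density estimate" (KDE).

A histogram groups the data into several bins (vertical bars); the height of each bin reflects how many data points fall into it. When the KDE is drawn as well, the heights are rescaled to a density so that the bars and the curve share one axis.

The KDE is a smooth, continuous curve that estimates the underlying distribution of the data points. You can think of it as a smoothed version of the histogram.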

# You can set the number of bins to determine the resolution of the graph.
# If you want to disable kde plot or histogram, you can pass kde=False or hist=False.
sns.distplot(data['Glucose'], bins=20)
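To make the two ingredients of that plot more concrete, here is a minimal NumPy-only sketch of what a histogram and a Gaussian KDE compute. The values are synthetic stand-ins for the Glucose column, and the helper name gaussian_kde_curve and the bandwidth rule are choices made for this illustration, not necessarily what seaborn does internally.

```python
import numpy as np

# Synthetic stand-in for data['Glucose'] (illustration only, not the real data)
rng = np.random.default_rng(0)
values = rng.normal(loc=120, scale=30, size=500)

# Histogram: split the value range into 20 equal-width bins and count the points in each
counts, bin_edges = np.histogram(values, bins=20)

def gaussian_kde_curve(sample, grid, bandwidth):
    """Hypothetical helper: average of Gaussian bumps centred on every sample point."""
    z = (grid[:, None] - sample[None, :]) / bandwidth
    bumps = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    return bumps.mean(axis=1)

# A common normal-reference rule of thumb for the bandwidth (an assumption of this sketch)
bandwidth = 1.06 * values.std(ddof=1) * len(values) ** (-1 / 5)
grid = np.linspace(values.min() - 3 * bandwidth, values.max() + 3 * bandwidth, 400)
density = gaussian_kde_curve(values, grid, bandwidth)

# A density curve should enclose an area of roughly 1
area = density.sum() * (grid[1] - grid[0])
print(counts)
print(round(area, 3))
```

A smaller bandwidth gives a wigglier curve that follows individual bins more closely, while a larger one smooths more aggressively.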

The easiest way to get a quick look at the distribution of all the variables is to use the pandas .hist() method, like so:

# Setting figsize a little larger so that all the plots would fit in an image
# Semicolon at the end of the command prevents Jupyter Notebook from displaying text output before the image
data.hist(figsize=(12,8));
# f, axes = plt.subplots(2, 2, figsize=(7, 7))
# sns.distplot(data['Glucose'], bins=20, kde=False, ax=axes[0, 0])
# sns.distplot(data['BloodPressure'], bins=20, kde=False, ax=axes[0, 1])

From these plots you can see that Blood Pressure, Glucose levels, and BMI are roughly normally distributed (as you would expect), while Age, Number of Pregnancies, and Insulin are strongly right-skewed, with a shape that resembles an exponential distribution.
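If you want a number to back up what the histograms suggest, skewness is a quick check: values near 0 indicate a roughly symmetric shape, and large positive values indicate a long right tail. The sketch below uses two synthetic columns (made up for this illustration, not taken from the dataset) and the pandas .skew() method.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic columns, for illustration only
demo = pd.DataFrame({
    "bell_shaped": rng.normal(loc=70, scale=12, size=1000),     # symmetric
    "right_skewed": rng.exponential(scale=80, size=1000),       # long right tail
})

# Sample skewness of every numeric column
skewness = demo.skew()
print(skewness)
```

On the real data the same check would be data.skew().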

Categorical plots

If you want to compare the categories of a categorical variable, for example how many instances of Diabetes vs Non-Diabetes there are in our dataset, you can use countplot():

sns.countplot(x='Outcome', data=data);
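The bar heights in a count plot are simply the number of rows per category, so you can get the same numbers as a table with pandas. A minimal sketch, using a tiny made-up Outcome column rather than the real one:

```python
import pandas as pd

# Tiny hypothetical Outcome column (1 = diabetes, 0 = no diabetes)
outcome = pd.Series([1, 0, 1, 0, 1, 0, 0, 0], name="Outcome")

# Number of rows in each class: these are the bar heights of a count plot
counts = outcome.value_counts().sort_index()

# The same information as fractions, handy for spotting class imbalance
shares = outcome.value_counts(normalize=True).sort_index()

print(counts)
print(shares)
```

On the real data this would be data['Outcome'].value_counts().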

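You can also create a simple bar plot to compare a summary statistic of a variable across categories.

Here, x= should be a categorical variable and y= should be numerical.

By default the height of each bar is the mean of y within that category, and the dark vertical line on top of each bar is an error bar showing the uncertainty of that estimate. Passing a custom function as the estimator (here np.std) changes what the bars show: each bar's height becomes the standard deviation of Glucose within that category.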

sns.barplot(x='Outcome',y='Glucose',data=data, estimator=np.std);
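To see which numbers end up as bar heights, you can compute them directly with a pandas groupby. This is a small sketch on made-up values; the column names mirror the dataset, but the numbers are hypothetical.

```python
import numpy as np
import pandas as pd

# Made-up values for illustration only
demo = pd.DataFrame({
    "Outcome": [0, 0, 0, 1, 1, 1],
    "Glucose": [90, 100, 110, 130, 150, 170],
})

grouped = demo.groupby("Outcome")["Glucose"]

# Default estimator (mean): the bar heights you get without estimator=
mean_heights = grouped.mean()

# estimator=np.std: np.std uses ddof=0 by default, so we match that here
std_heights = grouped.std(ddof=0)

print(mean_heights)
print(std_heights)
```

Note that pandas' own .std() defaults to ddof=1, so its result differs slightly from np.std unless you pass ddof=0.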

Box plots

Box plots conveniently summarize the distribution of a variable. The line in the middle of the box marks the median (the value in the middle of the sorted data), and the box around it spans the 25th to 75th percentiles (the middle 50% of the data). The "whiskers" extend to the most extreme points that are not considered outliers (by default, points within 1.5 times the box height from the box edges), and points beyond the whiskers are drawn individually as outliers.

Let's say we want to compare the glucose levels of people who did and did not get diabetes. We can do it like so:

# x is a categorical variable, y is the value we want to compare
sns.boxplot(x='Outcome',y='Glucose', data=data);
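The pieces of a box plot can be computed by hand, which helps when you want the actual numbers behind the picture. The sketch below uses a small made-up sample and the common 1.5 × IQR rule for the whiskers (seaborn's default, as far as I know, via its whis parameter).

```python
import numpy as np

# Small made-up sample with one obviously extreme value
values = np.array([80, 90, 95, 100, 105, 110, 120, 125, 130, 200])

# Box: 25th percentile, median, 75th percentile
q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1  # height of the box

# Points further than 1.5 * IQR from the box edges count as outliers
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers end at the most extreme points that are still inside the fences
inside = values[(values >= lower_fence) & (values <= upper_fence)]
whisker_low, whisker_high = inside.min(), inside.max()
outliers = values[(values < lower_fence) | (values > upper_fence)]

print(q1, median, q3)
print(whisker_low, whisker_high, outliers)
```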

Multivariate plots

Multivariate plots let you look at two or more variables together.

Let's say you want to see whether insulin levels rise as a person's blood glucose gets higher:

sns.jointplot(x="Glucose", y='Insulin', data=data);

Often you'll want to compare all the attributes against each other to see whether there are any interesting relationships between variables. To do that you can use pairplot(), which creates a grid of scatter plots, one for each pair of variables in your data.

Our dataset has a lot of variables, so to avoid generating a huge grid I have passed a vars parameter selecting only the variables I want to compare.

I have also added a hue parameter, which colors each point by its outcome, so we can easily see how these variables relate to whether or not the person gets diabetes.

Comparing a variable to itself would be pointless (the main diagonal would be full of straight lines), so the diagonal shows the distribution of each variable instead.

sns.pairplot(data=data, vars=["Glucose", "Insulin", "BMI", "Age"], hue='Outcome');

Correlation Matrix

Finally, it is extremely useful to look at the correlations between your variables.

Correlation tells you how closely the changes in two variables are related.

To do that, we first create a "correlation matrix", which holds the correlation of every pair of variables, and then plot it as a "heatmap".

The resulting plot lets you see the correlations between attributes at a glance. Values close to 1 or -1 mean a strong positive or negative correlation, and values near 0 mean little linear relationship. With the coolwarm colormap, the highest values in the matrix are drawn in red and the lowest in blue.

corr_matrix = data.corr()
sns.heatmap(corr_matrix, cmap='coolwarm')
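Each cell of that matrix is a correlation coefficient between two columns; by default pandas' .corr() uses the Pearson coefficient. Here is a minimal NumPy sketch of how that coefficient is computed, using made-up numbers and a hypothetical helper named pearson, checked against NumPy's own np.corrcoef.

```python
import numpy as np

def pearson(x, y):
    """Hypothetical helper: Pearson correlation of two equal-length sequences."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()  # deviations from the mean
    dy = y - y.mean()
    # Covariance of x and y divided by the product of their spreads
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())

# Made-up values for illustration
x = [1, 2, 3, 4, 5]
rises_with_x = [2, 4, 6, 8, 10]
falls_with_x = [10, 8, 6, 4, 2]
noisy = [3, 1, 4, 1, 5]

r_up = pearson(x, rises_with_x)
r_down = pearson(x, falls_with_x)
r_noisy = pearson(x, noisy)
r_noisy_numpy = np.corrcoef(x, noisy)[0, 1]

print(r_up, r_down, r_noisy, r_noisy_numpy)
```

The coefficient always lies between -1 and 1, and it only measures linear relationships, so a value near 0 does not rule out a curved one.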

Conclusion

Data analysis and visualisation are a powerful way to understand your data, and a good place to start when working on a machine learning problem. You can use them to see patterns in the data, understand how it should be modified, and pick models that are likely to work well for it.
