With Python + pandas + matplotlib, we will create a nicely formatted **heatmap** from a **correlation matrix** (a matrix of the correlation coefficients between every pair of variables).
As an example, we will create a heat map for the following correlation matrix of **grades in five subjects**.
Execution has been confirmed on Google Colab (Python 3.6.9); it works almost the same way in Jupyter Notebook.
Versions of the main libraries, as shown by `!pip list`:
```
matplotlib 3.1.2
numpy 1.17.4
pandas 0.25.3
```
To make Japanese usable in matplotlib figures, install and import japanize-matplotlib:
```
!pip install japanize-matplotlib
import japanize_matplotlib
```
This installs and imports japanize-matplotlib 1.0.5, so Japanese characters used in labels and elsewhere are rendered correctly instead of as garbled boxes ("tofu"). With the English labels used in this version, the step is optional.
The correlation matrix can be computed easily with the pandas `DataFrame.corr()` method.
```python
import pandas as pd

# Dummy data: scores of 10 students in five subjects
japanese = [76, 62, 71, 85, 96, 71, 68, 52, 85, 91]
social_studies = [71, 85, 64, 55, 79, 72, 73, 52, 84, 84]
math = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
science = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]
english = [59, 97, 71, 85, 58, 82, 70, 61, 79, 70]
df = pd.DataFrame({'Japanese': japanese, 'Social studies': social_studies,
                   'Math': math, 'Science': science, 'English': english})

# Calculate the correlation coefficients (Pearson by default)
df2 = df.corr()
display(df2)  # display() exists in Jupyter/Colab; use print(df2) elsewhere
```
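`df.corr()` uses the Pearson correlation coefficient by default. As a sanity check, the sketch below computes it from its definition (the sum of products of deviations from the mean, divided by the square root of the product of the sums of squared deviations) for two of the subjects and compares it with pandas. The helper name `pearson` is hypothetical and exists only for this illustration.

```python
import numpy as np
import pandas as pd

# Two of the dummy subjects from above
math = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
science = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]


def pearson(x, y):
    """Hypothetical helper: Pearson correlation computed from its definition."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()  # deviations from the mean
    dy = y - y.mean()
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())


df = pd.DataFrame({'Math': math, 'Science': science})
r_manual = pearson(math, science)
r_pandas = df.corr().at['Math', 'Science']  # method='pearson' is the default
print(f'manual: {r_manual:.4f}  pandas: {r_pandas:.4f}')
```

Both numbers should agree up to floating-point rounding; `np.corrcoef` gives the same value.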
Each element of the matrix takes a value between $-1.0$ and $1.0$. The closer a value is to $1.0$, the stronger the **positive correlation**; the closer it is to $-1.0$, the stronger the **negative correlation**. As a common rule of thumb, values between $-0.2$ and $0.2$ are regarded as **uncorrelated**.
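The rule of thumb above can be written as a small function. This is only a sketch: `classify_correlation` is a hypothetical name, and treating exactly $\pm 0.2$ as uncorrelated is an assumption, since the text does not say which side the boundary belongs to.

```python
def classify_correlation(r, threshold=0.2):
    """Hypothetical helper: label a correlation coefficient r.

    |r| <= threshold counts as uncorrelated (boundary handling is an assumption).
    """
    if not -1.0 <= r <= 1.0:
        raise ValueError('a correlation coefficient must lie in [-1.0, 1.0]')
    if abs(r) <= threshold:
        return 'uncorrelated'
    return 'positive correlation' if r > 0 else 'negative correlation'


for r in (0.85, 0.15, -0.2, -0.6):
    print(f'{r:+.2f}: {classify_correlation(r)}')
```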
Each diagonal element is the correlation of a subject with itself, so it is always $1.0$ (a perfect positive correlation).
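This can be checked directly, together with the fact that the matrix is symmetric (the correlation of A with B equals that of B with A), so only the cells above the diagonal carry distinct information. A minimal check that rebuilds the same dummy data so the snippet runs on its own:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'Japanese': [76, 62, 71, 85, 96, 71, 68, 52, 85, 91],
    'Social studies': [71, 85, 64, 55, 79, 72, 73, 52, 84, 84],
    'Math': [50, 78, 48, 64, 66, 62, 58, 50, 50, 60],
    'Science': [37, 90, 45, 56, 59, 56, 84, 86, 51, 61],
    'English': [59, 97, 71, 85, 58, 82, 70, 61, 79, 70],
})
corr = df.corr().to_numpy()

diagonal_is_one = np.allclose(np.diag(corr), 1.0)  # self-correlation
is_symmetric = np.allclose(corr, corr.T)            # r(A, B) == r(B, A)
# Cells strictly above the diagonal: n * (n - 1) / 2 distinct coefficients
upper = corr[np.triu_indices_from(corr, k=1)]
print(diagonal_is_one, is_symmetric, upper.size)
```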
Looking at the coefficients as a table of numbers, as above, makes it hard to grasp the overall picture, so let's visualize them with a heat map.
First, let's create a heat map with the minimum necessary code, without adjusting its appearance.
```python
# IPython magic: clear all variables (Jupyter/Colab only)
%reset -f
import pandas as pd
import matplotlib.pyplot as plt

# Dummy data
japanese = [76, 62, 71, 85, 96, 71, 68, 52, 85, 91]
social_studies = [71, 85, 64, 55, 79, 72, 73, 52, 84, 84]
math = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
science = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]
english = [59, 97, 71, 85, 58, 82, 70, 61, 79, 70]
df = pd.DataFrame({'Japanese': japanese, 'Social studies': social_studies,
                   'Math': math, 'Science': science, 'English': english})

# Calculate the correlation coefficients
df2 = df.corr()
display(df2)

# Output the correlation matrix as a heat map
plt.figure(dpi=120)
plt.imshow(df2, interpolation='nearest', vmin=-1.0, vmax=1.0)
plt.colorbar()

# Show the subject names (Japanese, Social studies, Math, Science, English) on the axes
n = len(df2.columns)  # number of subjects
plt.gca().set_xticks(range(n))
plt.gca().set_xticklabels(df2.columns)
plt.gca().set_yticks(range(n))
plt.gca().set_yticklabels(df2.columns)
plt.show()
```
Running this code produces a heat map. Reading it against the color bar on the right, the dark purple and blue cells indicate a **negative correlation**, and the bright yellow and green cells indicate a **positive correlation**.
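The minimal example drives everything through `plt` and `plt.gca()`. If you prefer matplotlib's object-oriented interface, or want to run the code as a plain script, an equivalent sketch could look like this. It assumes off-screen rendering with the Agg backend and writes the PNG into an in-memory buffer; both are choices made for this illustration, so drop `matplotlib.use('Agg')` in a notebook and pass a file name to `savefig` if you want a file on disk.

```python
import io

import matplotlib
matplotlib.use('Agg')  # assumption: off-screen rendering (script, no display)
import matplotlib.pyplot as plt
import pandas as pd

# Same dummy data as above
df = pd.DataFrame({
    'Japanese': [76, 62, 71, 85, 96, 71, 68, 52, 85, 91],
    'Social studies': [71, 85, 64, 55, 79, 72, 73, 52, 84, 84],
    'Math': [50, 78, 48, 64, 66, 62, 58, 50, 50, 60],
    'Science': [37, 90, 45, 56, 59, 56, 84, 86, 51, 61],
    'English': [59, 97, 71, 85, 58, 82, 70, 61, 79, 70],
})
df2 = df.corr()
n = len(df2.columns)  # number of subjects

fig, ax = plt.subplots(dpi=120)
im = ax.imshow(df2, interpolation='nearest', vmin=-1.0, vmax=1.0)
fig.colorbar(im, ax=ax)
ax.set_xticks(range(n))
ax.set_xticklabels(df2.columns)
ax.set_yticks(range(n))
ax.set_yticklabels(df2.columns)

buf = io.BytesIO()
fig.savefig(buf, format='png')  # or fig.savefig('heatmap.png') for a file
plt.close(fig)
```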
To be honest, the default settings do not produce an easy-to-read heat map.
Let's customize it to get a clean, intuitive heat map. The main points are:
- Paint the diagonal cells white and mark them with a diagonal line.
- Customize the color map so that the uncorrelated range is white.
- Add a grid (white lines between the cells).
- Print the correlation coefficient on each cell.
- Outline the numbers so that they stay readable whatever the background color.
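The custom color map used in the code below is built from five anchor points. Because `imshow` is called with `vmin=-1.0` and `vmax=1.0`, a coefficient $r$ lands at position $(r+1)/2$ on the color map, so the anchors at 0.30, 0.50 and 0.70 correspond to $r=-0.4$, $0$ and $+0.4$. The sketch below rebuilds the same anchors and samples a few values to show that the range around zero comes out white or nearly white; `mcolors.Normalize` is used here to mimic the `vmin`/`vmax` scaling that `imshow` applies.

```python
import matplotlib.colors as mcolors
import numpy as np

# Same anchors as the custom color map: (position in [0, 1], HSV -> RGB)
cl = [
    (0.00, mcolors.hsv_to_rgb((0.6, 1.0, 1))),  # saturated blue -> r = -1.0
    (0.30, mcolors.hsv_to_rgb((0.6, 0.1, 1))),  # pale blue      -> r = -0.4
    (0.50, mcolors.hsv_to_rgb((0.3, 0.0, 1))),  # white          -> r =  0.0
    (0.70, mcolors.hsv_to_rgb((0.0, 0.1, 1))),  # pale red       -> r = +0.4
    (1.00, mcolors.hsv_to_rgb((0.0, 1.0, 1))),  # saturated red  -> r = +1.0
]
ccm = mcolors.LinearSegmentedColormap.from_list('custom_cmap', cl)
norm = mcolors.Normalize(vmin=-1.0, vmax=1.0)  # same range as vmin/vmax in imshow

for r in (-1.0, -0.2, 0.0, 0.2, 1.0):
    rgba = ccm(norm(r))
    print(f'{r:+.1f} -> RGB {tuple(round(float(c), 2) for c in rgba[:3])}')
```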
Putting all of this into code gives the following:
```python
# IPython magic: clear all variables (Jupyter/Colab only)
%reset -f
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.ticker as ticker
import matplotlib.colors

# Dummy data
japanese = [76, 62, 71, 85, 96, 71, 68, 52, 85, 91]
social_studies = [71, 85, 64, 55, 79, 72, 73, 52, 84, 84]
math = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
science = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]
english = [59, 97, 71, 85, 58, 82, 70, 61, 79, 70]
df = pd.DataFrame({'Japanese': japanese, 'Social studies': social_studies,
                   'Math': math, 'Science': science, 'English': english})

# Calculate the correlation coefficients
df2 = df.corr()
# Set the diagonal to 0.0 so that it is drawn in the "no correlation" color (white)
for i in df2.index.values:
    df2.at[i, i] = 0.0

# Output the correlation matrix as a heat map
plt.figure(dpi=120)

# Custom color map: blue (-1.0) -> white (0.0) -> red (+1.0)
cl = list()
cl.append((0.00, matplotlib.colors.hsv_to_rgb((0.6, 1.0, 1))))
cl.append((0.30, matplotlib.colors.hsv_to_rgb((0.6, 0.1, 1))))
cl.append((0.50, matplotlib.colors.hsv_to_rgb((0.3, 0.0, 1))))
cl.append((0.70, matplotlib.colors.hsv_to_rgb((0.0, 0.1, 1))))
cl.append((1.00, matplotlib.colors.hsv_to_rgb((0.0, 1.0, 1))))
ccm = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', cl)
plt.imshow(df2, interpolation='nearest', vmin=-1.0, vmax=1.0, cmap=ccm)

# Color bar on the right side, with signed tick labels such as +0.5
fmt = lambda p, pos=None: f'${p:+.1f}$' if p != 0 else ' $0.0$'
cb = plt.colorbar(format=ticker.FuncFormatter(fmt))
cb.set_label('Correlation coefficient', fontsize=11)

# Subject names: x-axis labels on top, tick marks hidden
n = len(df2.columns)  # number of subjects
plt.gca().set_xticks(range(n))
plt.gca().set_xticklabels(df2.columns)
plt.gca().set_yticks(range(n))
plt.gca().set_yticklabels(df2.columns)
plt.tick_params(axis='x', which='both',
                top=True, bottom=False, labeltop=True, labelbottom=False)
plt.tick_params(axis='both', which='both', top=False, left=False)

# Grid: white lines on minor ticks placed at the boundaries between cells
plt.gca().set_xticks(np.arange(0.5, n - 1), minor=True)
plt.gca().set_yticks(np.arange(0.5, n - 1), minor=True)
plt.grid(which='minor', color='white', linewidth=1)

# Diagonal line across the blank diagonal cells
plt.plot([-0.5, n - 0.5], [-0.5, n - 0.5], color='black', linewidth=0.75)

# Print the correlation coefficients on the cells (text outlined in white)
tp = dict(horizontalalignment='center', verticalalignment='center')
ep = [path_effects.Stroke(linewidth=3, foreground='white'), path_effects.Normal()]
for y, i in enumerate(df2.index.values):
    for x, c in enumerate(df2.columns.values):
        if x != y:
            t = plt.text(x, y, f'{df2.at[i, c]:.2f}', **tp)
            t.set_path_effects(ep)
plt.show()
```
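If you draw several such heat maps, the annotation loop can be factored into a function that takes an `Axes` object. The sketch below is one way to do that, assuming the object-oriented API and off-screen rendering; `annotate_offdiagonal` is a hypothetical helper name, and the outline settings are copied from the loop above.

```python
import matplotlib
matplotlib.use('Agg')  # assumption: off-screen rendering; omit in a notebook
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import pandas as pd


def annotate_offdiagonal(ax, corr, digits=2):
    """Hypothetical helper: write every off-diagonal coefficient of `corr`
    at the center of its cell, outlined in white. Returns the Text objects."""
    outline = [path_effects.Stroke(linewidth=3, foreground='white'),
               path_effects.Normal()]
    texts = []
    for y, row in enumerate(corr.index):
        for x, col in enumerate(corr.columns):
            if x == y:
                continue  # leave the diagonal cells blank
            t = ax.text(x, y, f'{corr.at[row, col]:.{digits}f}',
                        horizontalalignment='center',
                        verticalalignment='center')
            t.set_path_effects(outline)
            texts.append(t)
    return texts


# Same dummy data as above
df = pd.DataFrame({
    'Japanese': [76, 62, 71, 85, 96, 71, 68, 52, 85, 91],
    'Social studies': [71, 85, 64, 55, 79, 72, 73, 52, 84, 84],
    'Math': [50, 78, 48, 64, 66, 62, 58, 50, 50, 60],
    'Science': [37, 90, 45, 56, 59, 56, 84, 86, 51, 61],
    'English': [59, 97, 71, 85, 58, 82, 70, 61, 79, 70],
})
df2 = df.corr()

fig, ax = plt.subplots(dpi=120)
ax.imshow(df2, interpolation='nearest', vmin=-1.0, vmax=1.0)
texts = annotate_offdiagonal(ax, df2)
print(len(texts))  # 5 subjects -> 5 * 4 = 20 labels
```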