If you plot a pandas DataFrame that has a hierarchical index (MultiIndex) as is, the x-axis tick labels are displayed as tuples, which is a bit disappointing.
```python
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

a = pd.DataFrame(np.random.random([6, 2]),
                 index=pd.MultiIndex.from_product([['group1', 'group2'],
                                                   ['item1', 'item2', 'item3']]),
                 columns=['data1', 'data2'])
a.plot.bar()
```
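The tuples come from the index itself: iterating over a MultiIndex yields one tuple of level values per row, and those tuples are what the default tick labels are built from. A quick check on the same index (a small sketch that prints instead of plotting):

```python
import pandas as pd

# Same two-level index as in the plotting example
index = pd.MultiIndex.from_product([['group1', 'group2'],
                                    ['item1', 'item2', 'item3']])

# One tuple per row, e.g. ('group1', 'item1')
print(list(index)[:2])

# Only the innermost level, i.e. the part worth keeping on each tick
print(list(index.get_level_values(index.nlevels - 1)))
```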
I wanted axis labels that make the hierarchical structure of the data easy to see, and for now the following approach seems to work.
```python
import pandas as pd
from matplotlib import pyplot as plt


def set_hierarchical_xlabels(index, ax=None,
                             bar_xmargin=0.1,    # gap at each end of a group's line, in x-axis (data) units
                             bar_yinterval=0.1,  # vertical spacing between levels, in axes coordinates (axes height = 1)
                             ):
    from itertools import groupby
    from matplotlib.lines import Line2D

    ax = ax or plt.gca()
    assert isinstance(index, pd.MultiIndex)
    # Tick labels: only the innermost level of each row, kept horizontal
    labels = ax.set_xticklabels([s for *_, s in index])
    for lb in labels:
        lb.set_rotation(0)

    # x in data coordinates, y in axes coordinates (0 = bottom edge of the axes)
    transform = ax.get_xaxis_transform()

    # i = 1 handles the second-innermost level, i = 2 the next one out, and so on
    for i in range(1, len(index.codes)):
        xpos0 = -0.5  # left edge of the current group (bars sit at x = 0, 1, 2, ...)
        for (*_, code), codes_iter in groupby(zip(*index.codes[:-i])):
            xpos1 = xpos0 + sum(1 for _ in codes_iter)  # right edge of the current group
            # Group label, centered under the group, just below its line
            ax.text((xpos0 + xpos1) / 2, bar_yinterval * (-i - 0.1),
                    index.levels[-i - 1][code],
                    transform=transform,
                    ha="center", va="top")
            # Horizontal line spanning the group; clip_on=False lets it be drawn outside the axes
            ax.add_line(Line2D([xpos0 + bar_xmargin, xpos1 - bar_xmargin],
                               [bar_yinterval * -i] * 2,
                               transform=transform,
                               color="k", clip_on=False))
            xpos0 = xpos1


a.plot.bar()
set_hierarchical_xlabels(a.index)
```
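The position arithmetic is the part that is easiest to get wrong when adapting the function, so it can help to check it without drawing anything. The sketch below is a hypothetical helper, `hierarchical_spans`, that reuses the same `groupby` loop and only records, for each outer level, every group's label and its left and right edges (the label is drawn at the midpoint of those edges). It assumes bars at x = 0, 1, 2, ..., as in the pandas bar plot above.

```python
from itertools import groupby

import pandas as pd


def hierarchical_spans(index):
    """Hypothetical helper: x-range of every group on every outer level.

    Returns (level, label, left, right) tuples. A group of n consecutive
    bars whose left edge is `left` has its right edge at `left + n`.
    """
    spans = []
    for i in range(1, index.nlevels):
        level = index.nlevels - 1 - i  # outer level handled in this pass
        left = -0.5                    # left edge of the first bar
        # Key = codes of levels 0..level; groupby merges only consecutive
        # equal keys, just like the loop in set_hierarchical_xlabels.
        for (*_, code), rows in groupby(zip(*index.codes[:-i])):
            right = left + sum(1 for _ in rows)
            spans.append((level, index.levels[level][code], left, right))
            left = right
    return spans


index = pd.MultiIndex.from_product([['group1', 'group2'],
                                    ['item1', 'item2', 'item3']])
for span in hierarchical_spans(index):
    print(span)
# (0, 'group1', -0.5, 2.5)
# (0, 'group2', 2.5, 5.5)
```

With a three-level index the same loop simply runs once more, first producing the spans of the middle level and then those of the outermost one.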
The details of the design will vary with your purpose and taste, so feel free to adapt the code above as needed. It also works with a MultiIndex of three or more levels.
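One thing worth keeping in mind when adapting it: the function groups rows with `itertools.groupby`, which only merges consecutive equal keys. If rows belonging to the same outer group are not adjacent, that group is split into several lines and labels. A minimal sketch with a deliberately interleaved index (`outer_runs` is a hypothetical helper reproducing the grouping on level 0) shows the effect; sorting first, for example with `a.sort_index()` on the DataFrame before plotting, makes each group contiguous again.

```python
from itertools import groupby

import pandas as pd

# Rows of group1 and group2 interleaved, so the outer level is not contiguous
index = pd.MultiIndex.from_tuples([
    ('group1', 'item1'), ('group2', 'item1'),
    ('group1', 'item2'), ('group2', 'item2'),
])


def outer_runs(idx):
    # Consecutive runs of the level-0 codes, one entry per drawn group
    return [idx.levels[0][code] for code, _ in groupby(idx.codes[0])]


print(outer_runs(index))                # four runs: each group drawn twice
print(outer_runs(index.sort_values()))  # two runs: one per group
```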