Sometimes you want to add text annotations around a plot. You can place text with `plt.text`, but that is tedious because every position has to be specified by hand. The following code uses `legend` instead, which places the text without any coordinates.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# IPython/Jupyter magic: render figures inline in the notebook output.
# This line is not Python syntax; remove it when running the code as a plain .py script.
%matplotlib inline
# Default font size used for all text in the figures below.
plt.rcParams['font.size'] = 15
def r2(y1, y2, digits=3):
    """Return the squared Pearson correlation of *y1* and *y2* as a string.

    The value is r**2, the square of the Pearson correlation coefficient.
    It measures linear association. It is not ``1 - SS_res / SS_tot`` computed
    from predictions against measurements; the two differ when the predictions
    are biased or mis-scaled.

    Parameters
    ----------
    y1, y2 : array_like
        1-D sequences of equal length.
    digits : int, optional
        Number of decimal places to round to (default 3).

    Returns
    -------
    str
        ``str()`` of the rounded value, ready to be concatenated into a
        plot label. A constant input gives ``'nan'``, because numpy's
        correlation is undefined there.
    """
    corr = np.corrcoef(y1, y2)[0, 1]
    # Square r: the function name and the plot labels promise R^2, not r
    # (r alone can be negative).
    return str(np.round(corr ** 2, digits))
# Two synthetic "measured vs predicted" data sets, a and b. Each predicted
# value is the measured value plus standard-normal noise. A fixed seed keeps
# the scatter points and the R^2 labels identical from run to run.
rng = np.random.default_rng(0)

xa_train = [1, 3, 5, 7]
xa_test = [2, 4, 6, 9]
ya_train = np.asarray(xa_train) + rng.standard_normal(4)
ya_test = np.asarray(xa_test) + rng.standard_normal(4)

xb_train = [1, 3, 5, 7]
xb_test = [2, 4, 6, 9]
yb_train = np.asarray(xb_train) + rng.standard_normal(4)
yb_test = np.asarray(xb_test) + rng.standard_normal(4)
def _draw_parity_panel(ax, x_train, y_train, x_test, y_test, title):
    """Draw a measured-vs-predicted scatter of train (black) and test (red) points on *ax*."""
    ax.scatter(x_train, y_train, color='k', label='train')
    ax.scatter(x_test, y_test, color='r', label='test')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ticks = [0, 2, 4, 6, 8, 10]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    # y = x reference line: points on it are predicted exactly.
    ax.plot([0, 10], [0, 10], color='gray', lw=0.5)
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel('measured')


def _draw_r2_panel(ax, x_train, y_train, x_test, y_test):
    """Turn *ax* into a blank area showing the train/test R^2 text as a frameless legend."""
    ax.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_visible(False)
    label = ('$R^2_{train}=$' + r2(x_train, y_train)
             + '\n$R^2_{test}=$' + r2(x_test, y_test))
    # An invisible point (alpha=0) gives the legend an entry to carry the
    # text, so the legend positions it and no coordinates are needed.
    ax.scatter(0, 0, label=label, alpha=0)
    ax.legend(frameon=False, loc='upper left')


# 2x2 layout: a tall top row holds the two scatter plots, and a short bottom row
# holds the R^2 text for the panel above it.
fig = plt.figure()
gs = gridspec.GridSpec(2, 2, figure=fig, width_ratios=[1, 1],
                       height_ratios=[4, 1], wspace=0.2, hspace=0.4)

# NOTE(review): each top panel shows both the train and the test points of a
# single data set (a, then b). The titles 'train'/'test' look like they were
# meant to name the data sets; confirm the intended titles.
ax_a = fig.add_subplot(gs[0])
_draw_parity_panel(ax_a, xa_train, ya_train, xa_test, ya_test, 'train')
ax_a.set_ylabel('predicted')

ax_b = fig.add_subplot(gs[1])
_draw_parity_panel(ax_b, xb_train, yb_train, xb_test, yb_test, 'test')
# The y-axis matches the left panel, so its tick labels are hidden.
ax_b.tick_params(left=False, labelleft=False)
# One shared legend, anchored just outside the right edge of the right panel.
ax_b.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)

_draw_r2_panel(fig.add_subplot(gs[2]), xa_train, ya_train, xa_test, ya_test)
_draw_r2_panel(fig.add_subplot(gs[3]), xb_train, yb_train, xb_test, yb_test)

plt.show()
Recommended Posts