When you draw a scatter plot with multiple clusters, the points can overlap and become hard to see. So I used `Plotly` to build a plot that lets you check one cluster at a time.
For example, suppose we have data with x-y coordinates split into 5 clusters. With `seaborn`, you can draw a plot like this in one line:
```python
sns.scatterplot(x="x", y="y", hue="class", data=df)
```
Left as is, though, the overlapping points are a little hard to tell apart, so let's set the transparency with `alpha`:
```python
sns.scatterplot(x="x", y="y", hue="class", data=df, alpha=0.5)
```
That helps a little, but with this data it is still hard to read.
So I thought it would be nice to plot the clusters one at a time, and tried doing that with `Plotly`.
First, import the libraries:
```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly
```
Next, create some dummy data:
```python
x0 = np.random.normal(2, 0.8, 400)
y0 = np.random.normal(2, 0.8, 400)
x1 = np.random.normal(3, 1.2, 600)
y1 = np.random.normal(6, 0.8, 600)
x2 = np.random.normal(4, 0.4, 200)
y2 = np.random.normal(4, 0.8, 200)
x3 = np.random.normal(1, 0.8, 300)
y3 = np.random.normal(3, 1.2, 300)
x4 = np.random.normal(1, 0.8, 300)
y4 = np.random.normal(5, 0.8, 300)

df = pd.DataFrame()
df["x"] = np.concatenate([x0, x1, x2, x3, x4])
df["y"] = np.concatenate([y0, y1, y2, y3, y4])
df["class"] = ["Cluster 0"]*400 + ["Cluster 1"]*600 + ["Cluster 2"]*200 + ["Cluster 3"]*300 + ["Cluster 4"]*300
```
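If you want to change the number of clusters or their parameters, the same kind of dummy data can also be built from a small parameter table. This is a sketch, not the code used above: the helper `make_dummy_clusters` and the `CLUSTER_PARAMS` table are hypothetical names, and it uses NumPy's `default_rng` generator with a fixed seed, so the actual numbers differ from the `np.random.normal` calls above even though the distributions are the same.

```python
import numpy as np
import pandas as pd

# Hypothetical parameter table: (mean_x, sd_x, mean_y, sd_y, n_points) per cluster,
# using the same values as the hand-written dummy data above
CLUSTER_PARAMS = [
    (2, 0.8, 2, 0.8, 400),
    (3, 1.2, 6, 0.8, 600),
    (4, 0.4, 4, 0.8, 200),
    (1, 0.8, 3, 1.2, 300),
    (1, 0.8, 5, 0.8, 300),
]


def make_dummy_clusters(params, seed=None):
    """Build a DataFrame with columns x, y, class from per-cluster normal distributions."""
    rng = np.random.default_rng(seed)
    frames = []
    for i, (mx, sx, my, sy, n) in enumerate(params):
        frames.append(pd.DataFrame({
            "x": rng.normal(mx, sx, n),
            "y": rng.normal(my, sy, n),
            "class": f"Cluster {i}",  # scalar is broadcast to all n rows
        }))
    # Stack the clusters and renumber the index 0..N-1
    return pd.concat(frames, ignore_index=True)


df = make_dummy_clusters(CLUSTER_PARAMS, seed=0)
print(df["class"].value_counts(sort=False))
```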
Now for the main part, the plotting function. Here is the full code first:
```python
def plotly_scatterplot(x, y, hue, data, title=""):
    # Cluster labels in order of first appearance; one trace per cluster
    cluster = data[hue].unique()
    n_cluster = len(cluster)
    # matplotlib's default color cycle, so the colors match the seaborn plot
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    fig = go.Figure()
    button = []

    # "all" button: every trace visible
    tf = [True]*n_cluster
    tmp = dict(label="all",
               method="update",
               args=[{"visible": tf}]
               )
    button.append(tmp)

    for i, clu in enumerate(cluster):
        fig.add_trace(
            go.Scatter(
                x=data[data[hue]==clu][x],
                y=data[data[hue]==clu][y],
                mode="markers",
                name=clu,
                marker=dict(color=colors[i % len(colors)])  # wrap around if there are more clusters than colors
            )
        )
        # Button for this cluster: only the i-th trace visible
        tf = [False]*n_cluster
        tf[i] = True
        tmp = dict(label=clu,
                   method="update",
                   args=[{"visible": tf}]
                   )
        button.append(tmp)

    fig.update_layout(
        updatemenus=[
            dict(type="buttons",
                 x=1.15,
                 y=1,
                 buttons=button
                 )
        ])

    # Fix both axis ranges (data range plus a 10% margin on each side)
    x_min = data[x].min()
    x_max = data[x].max()
    x_range = x_max - x_min
    y_min = data[y].min()
    y_max = data[y].max()
    y_range = y_max - y_min
    fig.update_xaxes(range=[x_min-x_range/10, x_max+x_range/10])
    fig.update_yaxes(range=[y_min-y_range/10, y_max+y_range/10])

    fig.update_layout(
        title_text=title,
        xaxis_title=x,
        yaxis_title=y,
        showlegend=False,
    )
    fig.show()
    #plotly.offline.plot(fig, filename='graph.html')
```
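The function also fixes both axis ranges to the data range plus a 10% margin on each side. Presumably this is so the axes stay put when you switch between clusters: setting an explicit `range` turns off Plotly's autorange, which would otherwise rescale to whichever trace is visible. The margin calculation can be pulled out into a small helper; `padded_range` is a hypothetical name, not part of Plotly.

```python
import numpy as np


def padded_range(values, frac=0.1):
    """Hypothetical helper: return [min - frac*span, max + frac*span] for 1-D values."""
    values = np.asarray(values, dtype=float)
    lo, hi = values.min(), values.max()
    span = hi - lo
    # Convert to plain floats so the result prints and compares cleanly
    return [float(lo - span * frac), float(hi + span * frac)]


# Possible use inside the plotting function:
# fig.update_xaxes(range=padded_range(data[x]))
# fig.update_yaxes(range=padded_range(data[y]))
print(padded_range([0, 10]))  # [-1.0, 11.0]
```

Computing the margin from each axis's own span keeps the x and y calculations independent, so a value from one axis cannot leak into the other.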
Sorry it's so long. There are two key points.
```python
fig.add_trace(
    go.Scatter(
        x=data[data[hue]==clu][x],
        y=data[data[hue]==clu][y],
        mode="markers",
        name=clu,
        marker=dict(color=colors[i % len(colors)])
    )
)
```
This part creates one scatter plot (trace) per cluster from the DataFrame passed to the function.
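The loop selects each cluster with a boolean mask, once for `x` and once for `y`. As an alternative sketch (not the author's code), `groupby(..., sort=False)` splits the DataFrame in a single pass and yields the clusters in order of first appearance, the same order that `unique()` gives. The helper `split_by_hue` and the `demo` frame are hypothetical.

```python
import pandas as pd


def split_by_hue(data, hue):
    """Hypothetical helper: yield (label, sub-DataFrame) pairs, one per value of `hue`."""
    # sort=False keeps groups in the order their labels first appear in the column
    for label, group in data.groupby(hue, sort=False):
        yield label, group


demo = pd.DataFrame({
    "x": [1, 2, 3, 4, 5],
    "y": [5, 4, 3, 2, 1],
    "class": ["B", "A", "B", "C", "A"],
})
for label, group in split_by_hue(demo, "class"):
    print(label, group["x"].tolist())
# B [1, 3]
# A [2, 5]
# C [4]
```

Inside the plotting loop, each `group` would then be passed directly as `go.Scatter(x=group[x], y=group[y], ...)`.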
`colors` holds the color strings of matplotlib's default color cycle (the colors `plt` picks automatically), so each cluster is assigned a color in order by its index `i`.
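With matplotlib's default style, that cycle contains 10 colors, so indexing it with a plain cluster number would raise an `IndexError` for an 11th cluster. A minimal sketch of wrapping around with the modulo operator; the hard-coded list below is assumed to match matplotlib's default "tab10" cycle, and `cluster_colors` is a hypothetical helper.

```python
# Assumed to match plt.rcParams['axes.prop_cycle'].by_key()['color']
# when no custom matplotlib style is active (the "tab10" colors)
DEFAULT_CYCLE = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
]


def cluster_colors(n_cluster, palette=DEFAULT_CYCLE):
    """Hypothetical helper: one color per cluster, reusing the palette when it runs out."""
    return [palette[i % len(palette)] for i in range(n_cluster)]


colors = cluster_colors(12)
print(colors[10] == colors[0])  # True: the 11th cluster reuses the first color
```

Reusing colors means two clusters can share one, which is acceptable here because only one cluster is shown at a time when a cluster button is selected.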
```python
tf = [False]*n_cluster
tf[i] = True
tmp = dict(label=clu,
           method="update",
           args=[{"visible": tf}]
           )
button.append(tmp)
```
`tf` holds a list of booleans such as `[False, True, False, False, False]`, which selects which traces to show or hide.
Here, `fig.add_trace` stacks 5 scatter plots on top of each other, and `tf` specifies which of them to display. Setting every entry to True, as in `tf = [True, True, True, True, True]`, displays the scatter plots for all of the data.
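The button list can be summarized as: one "all" button with every entry True, plus one button per cluster whose mask has a single True at that cluster's position (the traces are indexed in the order they were added). A sketch that builds the same dicts in isolation, without Plotly itself; `make_buttons` is a hypothetical name.

```python
def make_buttons(labels):
    """Hypothetical helper: an "all" button plus one button per label.

    Each "visible" list has one boolean per trace, in the order the traces were added.
    """
    n = len(labels)
    buttons = [dict(label="all", method="update", args=[{"visible": [True] * n}])]
    for i, label in enumerate(labels):
        mask = [False] * n
        mask[i] = True  # show only the i-th trace
        buttons.append(dict(label=label, method="update", args=[{"visible": mask}]))
    return buttons


buttons = make_buttons([f"Cluster {i}" for i in range(5)])
print(buttons[2]["args"][0]["visible"])  # [False, True, False, False, False]
```

The result could be passed as `fig.update_layout(updatemenus=[dict(type="buttons", buttons=make_buttons(list(cluster)))])`.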
Finally, call the function:

```python
plotly_scatterplot(x="x", y="y", hue="class", data=df, title="Scatter Plot")
```

This draws the plot shown at the beginning.
That's all!