3

我正在尝试为我的数据框中的每个类分配颜色,这是我的代码:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

knn = KNeighborsClassifier(n_neighbors=7)

# fitting the model
knn.fit(X_train, y_train)

# predict the response
pred = knn.predict(X_test)

dfp = pd.DataFrame(X_test)
dfp.columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
dfp["PClass"] = pred

pyo.init_notebook_mode()
data = [go.Scatter(x=dfp['SepalLengthCm'], y=dfp['SepalWidthCm'], 
                   text=dfp['PClass'],
                   mode='markers',
                   marker=dict(
                    color=dfp['PClass']))]

layout = go.Layout(title='Chart', hovermode='closest')
fig = go.Figure(data=data, layout=layout)

pyo.iplot(data)

在这里我的 df 看起来像:

SepalLengthCm   SepalWidthCm    PetalLengthCm   PetalWidthCm    PClass
       6.1           2.8             4.7         1.2    Iris-versicolor
      5.7            3.8             1.7         0.3        Iris-setosa
      7.7             2.6        6.9         2.3    Iris-virginica

所以问题是它没有根据dfp['PClass']列分配颜色,并且图上的每个点都是相同的颜色:黑色。即使悬停时每个点都根据其类别正确标记。任何想法为什么它不能正常工作?

4

3 回答 3

2

这是一个使用图形对象的示例:

import numpy as np
import pandas as pd
import plotly.offline as pyo
import plotly.graph_objs as go

# Create some random data
np.random.seed(42)
random_x = np.random.randint(1, 101, 100)
random_y = np.random.randint(1, 101, 100)

# Create two groups for the data
group = []
for letter in range(0,50):
    group.append("A")

for letter in range(0, 50):
    group.append("B")

# Create a dictionary with the three fields to include in the dataframe
group = np.array(group)
data = {
    '1': random_x,
    '2': random_y,
    '3': group
}

# Creat the dataframe
df = pd.DataFrame(data)

# Find the different groups
groups = df['3'].unique()

# Create as many traces as different groups there are and save them in data list
data = []
for group in groups:
    df_group = df[df['3'] == group]
    trace = go.Scatter(x=df_group['1'], 
                        y=df_group['2'],
                        mode='markers',
                        name=group)
    data.append(trace)

# Layout of the plot
layout = go.Layout(title='Grouping')
fig = go.Figure(data=data, layout=layout)

pyo.plot(fig)
于 2021-09-19T15:38:52.950 回答
1

在您的代码示例中,您尝试使用color=dfp['PClass']). 例如,这是一个逻辑ggplot,其中ggplot(mtcars, aes(x=wt, y=mpg, shape=cyl, color=cyl, size=cyl))wherecyl是一个分类变量。您将在页面下方看到一个示例。

但是对于情节来说,这是行不通的。colorin将只接受像本例中这样go.Scatter的数值:color = np.random.randn(500)

在此处输入图像描述

为了达到您想要的结果,您必须使用多个跟踪来构建您的绘图,如下例所示

在此处输入图像描述

于 2019-05-13T07:31:40.633 回答
0

You can do it using plotly express.

import plotly.express as px
fig = px.scatter(dfp, x='SepalLengthCm', y='SepalWidthCm', color='PClass')
于 2021-10-08T20:45:25.917 回答