我正在尝试用示例数据拟合一个示例概率图模型（贝叶斯网络），但在对模型拟合数据时遇到了类型错误。示例代码如下：
import numpy as np
import pandas as pd
from pgmpy.models import BayesianModel
# pgmpy fits tabular (discrete) CPDs: the MLE estimator treats every distinct
# value in a column as a separate state. Continuous floats from
# np.random.uniform would give ~one state per row (hundreds per variable) and
# an enormous CPD for 'cost', which has two parents. Draw discrete integer
# states (0 or 1) instead; real continuous data should be binned first
# (e.g. with pd.cut) before fitting.
data = pd.DataFrame(
    np.random.randint(low=0, high=2, size=(1000, 4)),
    columns=['cost', 'quality', 'location', 'no_of_people'],
)

# Positional split: first 750 rows for fitting, the remaining 250 held out.
train = data.iloc[:750]
# Held-out rows without 'no_of_people' (presumably to predict it later).
test = data.iloc[750:].drop('no_of_people', axis=1)

restaurant_model = BayesianModel([
    ('location', 'cost'),
    ('quality', 'cost'),
    ('location', 'no_of_people'),
    ('cost', 'no_of_people'),
])

# NOTE: the "Expected type int or type float, got dtype('int64')" TypeError
# comes from old pgmpy releases whose Factor.__init__ checks
# `values.dtype != int`. With NumPy on Windows, `int` maps to int32, so the
# int64 state counts produced by the MLE estimator are rejected even though
# they are integers. If the error persists, upgrade pgmpy.
restaurant_model.fit(train)
我遇到了以下错误：
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-173-8e3a85cb8b56> in <module>()
----> 1 restaurant_model.fit(train)
C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\models\BayesianModel.py in fit(self, data, estimator_type)
568 estimator = estimator_type(self, data)
569
--> 570 cpds_list = estimator.get_parameters()
571 self.add_cpds(*cpds_list)
572
C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\estimators\MLE.py in get_parameters(self)
64 state_counts = state_counts.reindex(sorted(state_counts.index))
65 cpd = TabularCPD(node, self.node_card[node],
---> 66 state_counts.values[:, np.newaxis])
67 cpd.normalize()
68 parameters.append(cpd)
C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\factors\CPD.py in __init__(self, variable, variable_card, values, evidence, evidence_card)
137 raise TypeError("Values must be a 2D list/array")
138
--> 139 super(TabularCPD, self).__init__(variables, cardinality, values.flatten('C'))
140
141 def __repr__(self):
C:\Users\pranav.waila\AppData\Local\Continuum\Anaconda3\lib\site-packages\pgmpy\factors\Factor.py in __init__(self, variables, cardinality, values)
98
99 if values.dtype != int and values.dtype != float:
--> 100 raise TypeError("Values: Expected type int or type float, got ", values.dtype)
101
102 if len(cardinality) != len(variables):
TypeError: ('Values: Expected type int or type float, got ', dtype('int64'))